keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/einsum_dense.py

@@ -1,3 +1,4 @@
+import math
 import re
 import string
 
@@ -5,6 +6,7 @@ import ml_dtypes
 import numpy as np
 
 from keras.src import activations
+from keras.src import backend
 from keras.src import constraints
 from keras.src import dtype_policies
 from keras.src import initializers
@@ -14,6 +16,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.EinsumDense")
@@ -131,6 +136,8 @@ class EinsumDense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        gptq_unpacked_column_size=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -150,6 +157,8 @@ class EinsumDense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.gptq_unpacked_column_size = gptq_unpacked_column_size
+        self.quantization_config = quantization_config
 
     def build(self, input_shape):
         shape_data = _analyze_einsum_string(
@@ -162,12 +171,16 @@ class EinsumDense(Layer):
         self.full_output_shape = tuple(full_output_shape)
         self.input_spec = InputSpec(ndim=len(input_shape))
         if self.quantization_mode is not None:
-            self.quantized_build(
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         # Skip creating a duplicate kernel variable when the layer is already
         # quantized to int8 or int4, because `quantized_build` has created the
         # appropriate kernel variable. For other modes (e.g., float8 or no
         # quantization), we still need the floating-point kernel.
-        if self.quantization_mode not in ("int8", "int4"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8, `self._kernel` will be added
             # in `self._int8_build`. Therefore, we skip it here.
             self._kernel = self.add_weight(
@@ -197,15 +210,62 @@ class EinsumDense(Layer):
 
     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
+        is_int4 = mode == "int4"
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+        # Handle int4 unpacking cases in one place
+        if is_int4:
+            kernel = quantizers.unpack_int4(
+                kernel,
+                self._orig_length_along_pack_axis,
+                self._int4_pack_axis,
+            )
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.gptq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+
+        # Apply LoRA if enabled
         if self.lora_enabled:
-
-                self.
-            )
-
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+                self.lora_kernel_a, self.lora_kernel_b
+            )
+
+        return kernel
 
     def compute_output_shape(self, _):
         return self.full_output_shape
@@ -239,6 +299,10 @@ class EinsumDense(Layer):
             raise ValueError(
                 "lora is already enabled. This can only be done once per layer."
             )
+        if self.quantization_mode == "gptq":
+            raise NotImplementedError(
+                "lora is not currently supported with GPTQ quantization."
+            )
         self._tracker.unlock()
         # Determine the appropriate (unpacked) kernel shape for LoRA.
         if self.quantization_mode == "int4":
@@ -277,26 +341,26 @@ class EinsumDense(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
+        # for None/gptq)
+        kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_kernel_with_merged_lora()`
+                store[str(idx)] = merged_kernel_scale
             else:
-
-
-                store[str(i)] = variable
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -304,25 +368,23 @@ class EinsumDense(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                target_variables.append(self.outputs_grad_amax_history)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
+
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
             else:
-
-
-                variable.assign(store[str(i)])
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -347,64 +409,103 @@ class EinsumDense(Layer):
             ),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
+        if self.gptq_unpacked_column_size:
+            config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}
 
-
-
-
-
-
-
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)
+
+    @property
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
-
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
+        """
+        return {
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "float8": [
+                "kernel",
+                "bias",
+                "inputs_scale",
+                "inputs_amax_history",
+                "kernel_scale",
+                "kernel_amax_history",
+                "outputs_grad_scale",
+                "outputs_grad_amax_history",
+            ],
+            "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
+        }
 
-    def quantized_build(self, kernel_shape, mode):
+    def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
+        elif mode == "gptq":
+            self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, kernel_shape):
+    def _int8_build(self, kernel_shape, config=None):
         self._set_quantization_info()
-        self.inputs_quantizer =
-
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
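The spec-driven `save_own_variables` / `load_own_variables` above index variables by their position in `variable_serialization_spec`, and a missing optional variable is skipped without consuming an index. A minimal standalone sketch of that rule (plain Python, not Keras code; the array placeholders are made up), assuming an int8-mode layer built without a bias:

# The "int8" ordering from `variable_serialization_spec` above.
spec = ["kernel", "bias", "kernel_scale"]
# Hypothetical stand-ins for this layer's variables; no bias was built.
present = {"kernel": "kernel_array", "kernel_scale": "scale_array"}

store = {}
idx = 0
for name in spec:
    if name == "bias" and name not in present:
        continue  # skipped entirely; `idx` is not incremented
    store[str(idx)] = present[name]
    idx += 1

print(store)  # {'0': 'kernel_array', '1': 'scale_array'}

Because the skipped entry never increments `idx`, the saved keys stay contiguous, which lets `load_own_variables` walk the same spec and land on the same keys.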
@@ -420,7 +521,244 @@ class EinsumDense(Layer):
             trainable=False,
         )
 
-    def
+    def _gptq_build(self, kernel_shape, config):
+        """
+        Allocate quantized kernel & params for EinsumDense.
+
+        Args:
+            kernel_shape: tuple/list; the layer's original kernel shape, e.g.
+                [in_features, out_features] or [in_features, heads, head_dim].
+            group_size: int; contiguous input-group size for quantization
+                (=-1 means per-output-channel with no grouping).
+        """
+        from keras.src.quantizers import gptq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_gptq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.gptq_unpacked_column_size = columns
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, n_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, n_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
+        if not self.is_gptq_calibrated:
+            W = self._kernel
+        else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
+            W = (
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.gptq_unpacked_column_size,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+        W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+        else:
+            raise ValueError("AWQ quantization only supports 2D or 3D kernels.")
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.awq_unpacked_column_size = columns
+
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(rows,),
+            initializer="ones",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, out_dim]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
         The packed int4 kernel stores two int4 values within a single int8
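`_gptq_build` and `_awq_build` above reduce the einsum kernel to a 2D (rows, columns) view, then allocate a packed uint8 kernel of shape (ceil(columns / 2), rows) for 4-bit weights and per-group scale/zero tensors of shape (columns, n_groups). A standalone sketch of that shape bookkeeping (plain Python mirroring the GPTQ branch above; the example kernel shape and group size are invented):

import math

def gptq_shapes(kernel_shape, group_size, weight_bits):
    # 2D kernels map directly; for 3D kernels the largest axis is treated
    # as the model dimension and the remaining axes are flattened.
    if len(kernel_shape) == 2:
        rows, columns = kernel_shape
    else:
        shape = list(kernel_shape)
        d_model_dim_index = shape.index(max(shape))
        if d_model_dim_index == 0:  # QKV-style projection
            rows, columns = shape[0], shape[1] * shape[2]
        else:  # attention-output-style projection
            rows, columns = shape[0] * shape[1], shape[2]
    n_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
    # Two 4-bit values are packed per byte along the column axis.
    packed_columns = (columns + 1) // 2 if weight_bits == 4 else columns
    return (packed_columns, rows), (columns, n_groups)

# A (768, 12, 64) QKV kernel with group_size=128 and 4-bit weights:
print(gptq_shapes((768, 12, 64), group_size=128, weight_bits=4))
# ((384, 768), (768, 6))  i.e. packed kernel shape, scale/zero shape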
@@ -432,9 +770,15 @@ class EinsumDense(Layer):
         self._set_quantization_info()
 
         # Quantizer for the inputs (per the reduced axes)
-        self.inputs_quantizer =
-
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)
 
         # Choose the axis to perform int4 packing - use the first reduced axis
         # for the kernel (analogous to the input dimension of a Dense layer).
@@ -556,13 +900,36 @@ class EinsumDense(Layer):
                 )
                 return (inputs_grad, None, None)
 
-
-
-
-
-
-
-
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int8 kernel).
+                if backend.backend() == "torch":
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, kernel_scale)
             return x, grad_fn
 
         x = einsum_with_inputs_gradient(
@@ -632,17 +999,38 @@ class EinsumDense(Layer):
                 return (inputs_grad, None, None)
 
             # Quantize inputs per `self.inputs_quantizer`.
-
-
-
-
-
-
-
-
-
-
-
+            if self.inputs_quantizer:
+                inputs_q, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int4 kernel).
+                if backend.backend() == "torch":
+                    # Align `kernel_scale` to the same layout as
+                    # `unpacked_kernel`.
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, kernel_scale)
             return x, grad_fn
 
         x = einsum_with_inputs_gradient(
@@ -756,30 +1144,40 @@ class EinsumDense(Layer):
             x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
-        if mode in ("int8", "int4"):
+        if mode in ("int8", "int4", "gptq", "awq"):
             self._set_quantization_info()
 
         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
-
-            self.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
            )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
             del self._kernel
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
-
-            self.
-
-
-
-
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=self._kernel_reduced_axes,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
 
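The updated `quantize()` above keeps the original one-argument call and adds an optional `config` that is stored on the layer and threaded through `quantized_build`. A hedged usage sketch (the layer equation and shapes are invented; only the int8 path is exercised here, since the GPTQ/AWQ branches read group size and weight bits from a quantization config and are normally driven by a calibration flow defined outside this file):

from keras import layers

layer = layers.EinsumDense(
    "abc,cd->abd", output_shape=(None, 64), bias_axes="d"
)
layer.build((None, 32, 128))  # kernel shape (128, 64)

# Unchanged behaviour: int8 with the default AbsMax quantizers.
layer.quantize("int8")

# New keyword: a quantization config object may override the defaults,
# e.g. layer.quantize("int8", config=...); the GPTQ/AWQ modes expect one.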
@@ -790,7 +1188,7 @@ class EinsumDense(Layer):
             )
             kernel_value = packed_kernel_value
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
+        self.quantized_build(kernel_shape, mode, self.quantization_config)
 
         # Assign values to the newly created variables.
         if mode in ("int8", "int4"):
@@ -799,7 +1197,14 @@ class EinsumDense(Layer):
 
         # Set new dtype policy
         if self.dtype_policy.quantization_mode is None:
-
+            policy_name = mode
+            if mode == "gptq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            elif mode == "awq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
             self.dtype_policy = policy
 
     def _get_kernel_scale_shape(self, kernel_shape):
@@ -860,7 +1265,7 @@ class EinsumDense(Layer):
         This is `None` if the layer is not quantized.
         """
         # If not a quantized layer, return the full-precision kernel directly.
-        if self.dtype_policy.quantization_mode
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         # If quantized but LoRA is not enabled, return the original quantized
|