keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/quantizers/gptq_core.py

@@ -6,10 +6,13 @@ from absl import logging
 
 from keras.src import ops
 from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer
 
 
 @contextmanager
@@ -190,38 +193,6 @@ def get_dataloader(
     return samples.astype(np.int32)[:, None, :]
 
 
-def _get_backbone_layers(model):
-    """Extract embedding and transformer layers from a KerasHub model."""
-    backbone = model.backbone
-    if not hasattr(backbone, "transformer_layers"):
-        raise ValueError(
-            "The model's backbone does not have a 'transformer_layers' "
-            "attribute. Please ensure you are using a standard KerasHub "
-            "transformer model."
-        )
-    transformer_blocks = backbone.transformer_layers
-
-    embedding_layer = None
-    if hasattr(backbone, "token_embedding"):
-        embedding_layer = backbone.token_embedding
-    elif hasattr(backbone, "embedding"):
-        embedding_layer = backbone.embedding
-    return embedding_layer, transformer_blocks
-
-
-def _get_custom_layers(model):
-    """Heuristic for extracting embedding + transformer blocks from a custom
-    model."""
-    embedding_layer = None
-    transformer_blocks = []
-    for layer in model.layers:
-        if isinstance(layer, Embedding) and embedding_layer is None:
-            embedding_layer = layer
-        elif getattr(layer, "_layers", None):  # container-like block
-            transformer_blocks.append(layer)
-    return embedding_layer, transformer_blocks
-
-
 def find_layers_in_block(block):
     """
     Finds all Dense and EinsumDense layers in a transformer block.
@@ -239,39 +210,31 @@ def find_layers_in_block(block):
     return found_layers
 
 
-def apply_gptq_layerwise(model, dataloader, config):
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
-    This function
-
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     The core logic operates as follows:
-
-
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
         block, it uses temporary hooks to capture the input activations of
         each target layer during a forward pass with the calibration data.
-
+    2. These captured activations are used to compute the Hessian matrix for
         each layer's weights.
-
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
         quantized weights that minimize the error introduced.
-
+    4. The output activations from the current block are then used as the
         input for the next block, ensuring that quantization errors are
         accounted for throughout the model.
 
     Args:
-
-            attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
+        dataloader: An iterable providing calibration data.
         config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
@@ -281,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
     num_samples = config.num_samples
 
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        embedding_layer, transformer_blocks = _get_backbone_layers(model)
-    else:
-        logging.info("Detected custom model structure.")
-        embedding_layer, transformer_blocks = _get_custom_layers(model)
 
-
-
-
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
         )
 
-    # Initial inputs are the outputs of the
-    inputs = [
-
-
-
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
     num_samples = min(num_samples, len(inputs))
 
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -313,10 +269,19 @@ def apply_gptq_layerwise(model, dataloader, config):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
        if not sub_layers_map:
            logging.info(
-                f" No
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
            )
        else:
            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -354,11 +319,30 @@ def apply_gptq_layerwise(model, dataloader, config):
     logging.info("Quantization process complete.")
 
 
-def gptq_quantize(model, config):
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
 
     # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
@@ -373,4 +357,92 @@ def gptq_quantize(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(
+    apply_gptq_layerwise(
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
+
+
+def get_group_size_for_layer(layer, config):
+    """Determine the group size for GPTQ quantization.
+
+    The group size can be specified either through the `config` argument
+    or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the group size should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `group_size` attribute.
+    Returns:
+        int. The determined group size for GPTQ quantization.
+    Raises:
+        ValueError: If the group size is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
+                f"Got: {type(policy)}"
+            )
+        return policy.group_size
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the group_size must be specified"
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
+
+
+def get_weight_bits_for_layer(layer, config):
+    """Determine the number of weight bits for GPTQ quantization.
+
+    The number of weight bits can be specified either through the `config`
+    argument or through the `dtype_policy` if it is of type
+    `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the weight bits should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `weight_bits` attribute.
+    Returns:
+        int. The determined number of weight bits for GPTQ quantization.
+    Raises:
+        ValueError: If the weight bits is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.weight_bits
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.weight_bits
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
+                f"Got: {type(policy)}"
+            )
+        return policy.weight_bits
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the weight_bits must be specified"
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
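The new GPTQ path no longer walks `model.backbone` or guesses at custom architectures; callers describe the model themselves. As a rough illustration (the toy model and its layer sizes below are invented; only the dictionary keys, the `filters` argument, and the `get_quantization_layer_structure(mode)` hook come from the diff above), a model can expose the structure that `apply_gptq_layerwise` consumes like this:

```python
import keras


class TinyTransformerLM(keras.Model):
    """Toy model used only to illustrate the structure dictionary."""

    def __init__(self, vocab_size=1000, width=32, **kwargs):
        super().__init__(**kwargs)
        self.embedding = keras.layers.Embedding(vocab_size, width)
        self.blocks = [
            keras.Sequential(
                [
                    keras.layers.Dense(width * 2, activation="relu"),
                    keras.layers.Dense(width),
                ]
            )
            for _ in range(2)
        ]
        self.head = keras.layers.Dense(vocab_size)

    def call(self, token_ids):
        x = self.embedding(token_ids)
        for block in self.blocks:
            x = block(x)
        return self.head(x)

    def get_quantization_layer_structure(self, mode):
        # Shape expected by apply_gptq_layerwise: the pre-block layers are run
        # once over the calibration batches to produce the initial activations,
        # then each sequential block is quantized in order.
        return {
            "pre_block_layers": [self.embedding],
            "sequential_blocks": list(self.blocks),
        }


model = TinyTransformerLM()
structure = model.get_quantization_layer_structure("gptq")
print(len(structure["sequential_blocks"]))  # 2
```

A real run would pair this structure with a `GPTQConfig` carrying a calibration `dataset` and `tokenizer` (the two fields `gptq_quantize` now validates up front) before invoking the quantization entry point.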
keras/src/quantizers/quantization_config.py (new file)

@@ -0,0 +1,232 @@
+from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QUANTIZATION_MODES
+from keras.src.saving import serialization_lib
+
+
+@keras_export("keras.quantizers.QuantizationConfig")
+class QuantizationConfig:
+    """Base class for quantization configs.
+
+    Subclasses must implement the `mode` property and the `get_config` and
+    `from_config` class methods.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer=None):
+        self.weight_quantizer = weight_quantizer
+        self.activation_quantizer = activation_quantizer
+
+    @property
+    def mode(self):
+        raise NotImplementedError(
+            "Subclasses must implement this property. Do not instantiate "
+            "QuantizationConfig directly."
+        )
+
+    def get_config(self):
+        return {
+            "weight_quantizer": serialization_lib.serialize_keras_object(
+                self.weight_quantizer
+            ),
+            "activation_quantizer": serialization_lib.serialize_keras_object(
+                self.activation_quantizer
+            ),
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        weight_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("weight_quantizer")
+        )
+        activation_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("activation_quantizer")
+        )
+        return cls(
+            weight_quantizer=weight_quantizer,
+            activation_quantizer=activation_quantizer,
+        )
+
+    @staticmethod
+    def weight_quantizer_or_default(config, default):
+        if config is not None and config.weight_quantizer is not None:
+            return config.weight_quantizer
+        return default
+
+    @staticmethod
+    def activation_quantizer_or_default(config, default):
+        if config is not None:
+            return config.activation_quantizer
+        return default
+
+
+@keras_export("keras.quantizers.Int8QuantizationConfig")
+class Int8QuantizationConfig(QuantizationConfig):
+    """Int8 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+        if activation_quantizer == "default":
+            activation_quantizer = AbsMaxQuantizer()
+        super().__init__(weight_quantizer, activation_quantizer)
+        if self.weight_quantizer is not None:
+            if self.weight_quantizer.output_dtype != "int8":
+                raise ValueError(
+                    "Int8QuantizationConfig requires a weight_quantizer "
+                    "with output_dtype='int8'. Received: "
+                    f"output_dtype={self.weight_quantizer.output_dtype}"
+                )
+
+    @property
+    def mode(self):
+        return "int8"
+
+
+@keras_export("keras.quantizers.Int4QuantizationConfig")
+class Int4QuantizationConfig(QuantizationConfig):
+    """Int4 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+        if activation_quantizer == "default":
+            activation_quantizer = AbsMaxQuantizer()
+        super().__init__(weight_quantizer, activation_quantizer)
+        if self.weight_quantizer is not None:
+            if self.weight_quantizer.value_range != (-8, 7):
+                raise ValueError(
+                    "Int4QuantizationConfig requires a weight_quantizer "
+                    "with value_range=(-8, 7). Received: "
+                    f"value_range={self.weight_quantizer.value_range}"
+                )
+
+            if self.weight_quantizer.output_dtype != "int8":
+                raise ValueError(
+                    "Int4QuantizationConfig requires a weight_quantizer "
+                    "with output_dtype='int8'. Received: "
+                    f"output_dtype={self.weight_quantizer.output_dtype}"
+                )
+
+    @property
+    def mode(self):
+        return "int4"
+
+
+@keras_export("keras.quantizers.Float8QuantizationConfig")
+class Float8QuantizationConfig(QuantizationConfig):
+    """FP8 quantization config.
+
+    FP8 mixed-precision training does not support user defined quantizers.
+    This config is only used to indicate that FP8 mixed-precision training
+    should be used.
+    """
+
+    def __init__(self):
+        super().__init__(None, None)
+
+    @property
+    def mode(self):
+        return "float8"
+
+    def get_config(self):
+        return {}
+
+    @classmethod
+    def from_config(cls, config):
+        return cls()
+
+
+def validate_and_resolve_config(mode, config):
+    """Validate and resolve quantization config.
+
+    This function validates the quantization config and resolves the mode.
+    If mode is not provided, it is inferred from the config.
+    If config is not provided, a default config is inferred from the mode.
+
+    Args:
+        mode: Quantization mode.
+        config: Quantization config.
+    """
+    # 1. Backwards Compatibility: Handle string shortcuts.
+    if isinstance(config, str):
+        mode = config
+        config = None
+
+    _validate_mode(mode)
+
+    # 2. Resolve "mode" into a Config object.
+    if config is None:
+        if mode == "int8":
+            config = Int8QuantizationConfig()
+        elif mode == "int4":
+            config = Int4QuantizationConfig()
+        elif mode == "float8":
+            config = Float8QuantizationConfig()
+        elif mode == "gptq":
+            raise ValueError(
+                "For GPTQ, you must pass a `GPTQConfig` object in the "
+                "`config` argument."
+            )
+        else:
+            if mode is not None:
+                raise ValueError(
+                    f"Invalid quantization mode. Received: mode={mode}"
+                )
+            raise ValueError(
+                "You must provide either `mode` or `config` to `quantize`."
+            )
+    else:
+        if not isinstance(config, QuantizationConfig):
+            raise ValueError(
+                "Argument `config` must be an instance of "
+                "`QuantizationConfig`. "
+                f"Received: config={config} (of type {type(config)})"
+            )
+
+    # 3. Validation: Prevent contradictions.
+    if mode is not None and config.mode != mode:
+        raise ValueError(
+            f"Contradictory arguments: mode='{mode}' but "
+            f"config.mode='{config.mode}'"
+        )
+
+    # Ensure mode is consistent.
+    mode = config.mode
+
+    # Ensure the mode derived from the config is valid.
+    _validate_mode(mode)
+
+    if mode == "gptq":
+        from keras.src.quantizers.gptq_config import GPTQConfig
+
+        if not isinstance(config, GPTQConfig):
+            raise ValueError(
+                "Mode 'gptq' requires a valid `config` argument of type "
+                f"`GPTQConfig`. Received: {type(config)}"
+            )
+
+    return config
+
+
+def _validate_mode(mode):
+    """Validates quantization mode."""
+    if mode is not None and mode not in QUANTIZATION_MODES:
+        raise ValueError(
+            "Invalid quantization mode. "
+            f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
+        )