keras_nightly-3.12.0.dev2025092403-py3-none-any.whl → keras_nightly-3.14.0.dev2026010104-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/quantizers/gptq_core.py

@@ -6,10 +6,13 @@ from absl import logging
 
 from keras.src import ops
 from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer
 
 
 @contextmanager
@@ -190,38 +193,6 @@ def get_dataloader(
     return samples.astype(np.int32)[:, None, :]
 
 
-def _get_backbone_layers(model):
-    """Extract embedding and transformer layers from a KerasHub model."""
-    backbone = model.backbone
-    if not hasattr(backbone, "transformer_layers"):
-        raise ValueError(
-            "The model's backbone does not have a 'transformer_layers' "
-            "attribute. Please ensure you are using a standard KerasHub "
-            "transformer model."
-        )
-    transformer_blocks = backbone.transformer_layers
-
-    embedding_layer = None
-    if hasattr(backbone, "token_embedding"):
-        embedding_layer = backbone.token_embedding
-    elif hasattr(backbone, "embedding"):
-        embedding_layer = backbone.embedding
-    return embedding_layer, transformer_blocks
-
-
-def _get_custom_layers(model):
-    """Heuristic for extracting embedding + transformer blocks from a custom
-    model."""
-    embedding_layer = None
-    transformer_blocks = []
-    for layer in model.layers:
-        if isinstance(layer, Embedding) and embedding_layer is None:
-            embedding_layer = layer
-        elif getattr(layer, "_layers", None):  # container-like block
-            transformer_blocks.append(layer)
-    return embedding_layer, transformer_blocks
-
-
 def find_layers_in_block(block):
     """
     Finds all Dense and EinsumDense layers in a transformer block.
@@ -239,39 +210,31 @@ def find_layers_in_block(block):
     return found_layers
 
 
-def apply_gptq_layerwise(model, dataloader, config):
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
-    This function is designed to work with common transformer architectures,
-    like those provided by KerasHub. It automatically discovers the model's
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     The core logic operates as follows:
-    1. It automatically detects the model's structure, identifying the main
-       embedding layer and a sequence of transformer blocks.
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
        block, it uses temporary hooks to capture the input activations of
       each target layer during a forward pass with the calibration data.
-    3. These captured activations are used to compute the Hessian matrix for
+    2. These captured activations are used to compute the Hessian matrix for
       each layer's weights.
-    4. The GPTQ algorithm is then applied to each layer to find the optimal
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
       quantized weights that minimize the error introduced.
-    5. The output activations from the current block are then used as the
+    4. The output activations from the current block are then used as the
       input for the next block, ensuring that quantization errors are
       accounted for throughout the model.
 
     Args:
-        model: The Keras model instance to be quantized. The function will
-            attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
+        dataloader: An iterable providing calibration data.
         config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
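
Step 2 above is the heart of GPTQ: the captured activations X that feed a layer define a Hessian proportional to X^T X, which then guides the error-minimizing rounding of that layer's weights. As a rough, self-contained illustration only (the exact shapes, scaling, and dampening live in keras.src.quantizers.gptq.GPTQ and are not part of this hunk):

    import numpy as np

    def hessian_from_activations(x, damping=0.01):
        # Flatten (batch, sequence) dims into rows of per-token features.
        x = x.reshape(-1, x.shape[-1]).astype("float64")
        # GPTQ's Hessian is proportional to 2 * X^T X over the calibration set.
        h = 2.0 * (x.T @ x) / x.shape[0]
        # Dampen the diagonal so the later Cholesky factorization stays stable.
        h += damping * np.mean(np.diag(h)) * np.eye(h.shape[0])
        return h

    acts = np.random.randn(8, 16, 64)  # (batch, seq, features)
    print(hessian_from_activations(acts).shape)  # (64, 64)
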
@@ -281,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
     num_samples = config.num_samples
 
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        embedding_layer, transformer_blocks = _get_backbone_layers(model)
-    else:
-        logging.info("Detected custom model structure.")
-        embedding_layer, transformer_blocks = _get_custom_layers(model)
 
-    if embedding_layer is None:
-        raise ValueError(
-            "Could not automatically find an embedding layer in the model."
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "Could not automatically find any transformer-like blocks to "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
         )
 
-    # Initial inputs are the outputs of the token embedding layer
-    inputs = [
-        embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
-        for batch in dataloader
-    ]
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
     num_samples = min(num_samples, len(inputs))
 
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
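
Concretely, the `structure` argument is just a dictionary of layer lists. A hypothetical example for a toy stack of layers (real KerasHub models would supply their own mapping; the layer names here are invented):

    import keras

    embed = keras.layers.Embedding(input_dim=1000, output_dim=64)
    block_1 = keras.layers.Dense(64, activation="relu")
    block_2 = keras.layers.Dense(64, activation="relu")

    structure = {
        # Run once on the raw int32 token batches to produce initial inputs.
        "pre_block_layers": [embed],
        # Quantized in order; each block's outputs feed the next block.
        "sequential_blocks": [block_1, block_2],
    }
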
@@ -313,10 +269,19 @@ def apply_gptq_layerwise(model, dataloader, config):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
         if not sub_layers_map:
             logging.info(
-                f" No Dense or EinsumDense layers found in block {block_idx}. "
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -354,11 +319,30 @@ def apply_gptq_layerwise(model, dataloader, config):
     logging.info("Quantization process complete.")
 
 
-def gptq_quantize(model, config):
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-    Top-level function to quantize a Keras model using GPTQ.
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-    logging.info("Starting GPTQ quantization process...")
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
 
     # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
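
With these checks in place, the expected calling pattern looks roughly as follows. This is a sketch: `dataset`, `tokenizer`, and `num_samples` are grounded in the checks above, but the full `GPTQConfig` signature lives in gptq_config.py (shown in this diff only as +36 -1), so anything beyond those three arguments is an assumption:

    from keras.quantizers import GPTQConfig  # assuming the public re-export

    config = GPTQConfig(
        dataset=calibration_texts,  # required, per the check above
        tokenizer=tokenizer,        # required, per the check above
        num_samples=128,
    )
    # model.quantize() is expected to supply the layer structure, or raise
    # the ValueError above when none can be resolved.
    model.quantize("gptq", config=config)
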
@@ -373,4 +357,92 @@ def gptq_quantize(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(model, calibration_dataloader, config)
+    apply_gptq_layerwise(
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
+
+
+def get_group_size_for_layer(layer, config):
+    """Determine the group size for GPTQ quantization.
+
+    The group size can be specified either through the `config` argument
+    or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the group size should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `group_size` attribute.
+    Returns:
+        int. The determined group size for GPTQ quantization.
+    Raises:
+        ValueError: If the group size is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`. "
+                f"Got: {type(policy)}"
+            )
+        return policy.group_size
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the group_size must be specified "
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
+
+
+def get_weight_bits_for_layer(layer, config):
+    """Determine the number of weight bits for GPTQ quantization.
+
+    The number of weight bits can be specified either through the `config`
+    argument or through the `dtype_policy` if it is of type
+    `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the weight bits should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `weight_bits` attribute.
+    Returns:
+        int. The determined number of weight bits for GPTQ quantization.
+    Raises:
+        ValueError: If the weight bits are not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.weight_bits
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.weight_bits
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`. "
+                f"Got: {type(policy)}"
+            )
+        return policy.weight_bits
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the weight_bits must be specified "
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
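
The two helpers mirror each other, and both encode the same resolution order: an explicit `GPTQConfig` wins, then the layer's own `GPTQDTypePolicy`, then a `GPTQDTypePolicy` entry in a `DTypePolicyMap` keyed by `layer.path`. Usage sketch, where `layer` is any GPTQ-quantized layer:

    # At quantize() time a GPTQConfig is on hand, so it takes precedence:
    group_size = get_group_size_for_layer(layer, config)
    weight_bits = get_weight_bits_for_layer(layer, config)

    # After reloading a saved model there is no config object; the values
    # are recovered from the serialized dtype policy instead:
    group_size = get_group_size_for_layer(layer, config=None)
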
keras/src/quantizers/quantization_config.py (new file)

@@ -0,0 +1,232 @@
+from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QUANTIZATION_MODES
+from keras.src.saving import serialization_lib
+
+
+@keras_export("keras.quantizers.QuantizationConfig")
+class QuantizationConfig:
+    """Base class for quantization configs.
+
+    Subclasses must implement the `mode` property and the `get_config` and
+    `from_config` class methods.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer=None):
+        self.weight_quantizer = weight_quantizer
+        self.activation_quantizer = activation_quantizer
+
+    @property
+    def mode(self):
+        raise NotImplementedError(
+            "Subclasses must implement this property. Do not instantiate "
+            "QuantizationConfig directly."
+        )
+
+    def get_config(self):
+        return {
+            "weight_quantizer": serialization_lib.serialize_keras_object(
+                self.weight_quantizer
+            ),
+            "activation_quantizer": serialization_lib.serialize_keras_object(
+                self.activation_quantizer
+            ),
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        weight_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("weight_quantizer")
+        )
+        activation_quantizer = serialization_lib.deserialize_keras_object(
+            config.get("activation_quantizer")
+        )
+        return cls(
+            weight_quantizer=weight_quantizer,
+            activation_quantizer=activation_quantizer,
+        )
+
+    @staticmethod
+    def weight_quantizer_or_default(config, default):
+        if config is not None and config.weight_quantizer is not None:
+            return config.weight_quantizer
+        return default
+
+    @staticmethod
+    def activation_quantizer_or_default(config, default):
+        if config is not None:
+            return config.activation_quantizer
+        return default
+
+
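
The base class gives subclasses serialization plus two fallback helpers that layers can use when the caller leaves a quantizer unset. A small sketch (`my_default` stands in for whatever quantizer a layer would pick on its own):

    from keras.src.quantizers.quantizers import AbsMaxQuantizer

    my_default = AbsMaxQuantizer()
    # No config at all: fall back to the layer-provided default.
    wq = QuantizationConfig.weight_quantizer_or_default(None, my_default)
    assert wq is my_default

Note the asymmetry: once a config is present, `activation_quantizer_or_default` returns the config's value even when it is None, which lets a config explicitly disable activation quantization.
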
+@keras_export("keras.quantizers.Int8QuantizationConfig")
+class Int8QuantizationConfig(QuantizationConfig):
+    """Int8 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+        if activation_quantizer == "default":
+            activation_quantizer = AbsMaxQuantizer()
+        super().__init__(weight_quantizer, activation_quantizer)
+        if self.weight_quantizer is not None:
+            if self.weight_quantizer.output_dtype != "int8":
+                raise ValueError(
+                    "Int8QuantizationConfig requires a weight_quantizer "
+                    "with output_dtype='int8'. Received: "
+                    f"output_dtype={self.weight_quantizer.output_dtype}"
+                )
+
+    @property
+    def mode(self):
+        return "int8"
+
+
+@keras_export("keras.quantizers.Int4QuantizationConfig")
+class Int4QuantizationConfig(QuantizationConfig):
+    """Int4 quantization config.
+
+    Args:
+        weight_quantizer: Quantizer for weights.
+        activation_quantizer: Quantizer for activations. If "default", uses
+            AbsMaxQuantizer with axis=-1.
+    """
+
+    def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+        from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+        if activation_quantizer == "default":
+            activation_quantizer = AbsMaxQuantizer()
+        super().__init__(weight_quantizer, activation_quantizer)
+        if self.weight_quantizer is not None:
+            if self.weight_quantizer.value_range != (-8, 7):
+                raise ValueError(
+                    "Int4QuantizationConfig requires a weight_quantizer "
+                    "with value_range=(-8, 7). Received: "
+                    f"value_range={self.weight_quantizer.value_range}"
+                )
+
+            if self.weight_quantizer.output_dtype != "int8":
+                raise ValueError(
+                    "Int4QuantizationConfig requires a weight_quantizer "
+                    "with output_dtype='int8'. Received: "
+                    f"output_dtype={self.weight_quantizer.output_dtype}"
+                )
+
+    @property
+    def mode(self):
+        return "int4"
+
+
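
The int4 path stores 4-bit values in an int8 carrier, which is exactly what the two constructor checks above enforce on a custom weight quantizer. A minimal sketch (assuming `AbsMaxQuantizer` accepts a `value_range` argument, consistent with its zero-argument use above implying sensible defaults):

    from keras.src.quantizers.quantizers import AbsMaxQuantizer

    # 4-bit signed range (-8, 7), carried in int8 storage:
    wq = AbsMaxQuantizer(value_range=(-8, 7))
    cfg = Int4QuantizationConfig(weight_quantizer=wq)
    assert cfg.mode == "int4"
    # A quantizer left at the default (-127, 127) range would raise ValueError.
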
+@keras_export("keras.quantizers.Float8QuantizationConfig")
+class Float8QuantizationConfig(QuantizationConfig):
+    """FP8 quantization config.
+
+    FP8 mixed-precision training does not support user defined quantizers.
+    This config is only used to indicate that FP8 mixed-precision training
+    should be used.
+    """
+
+    def __init__(self):
+        super().__init__(None, None)
+
+    @property
+    def mode(self):
+        return "float8"
+
+    def get_config(self):
+        return {}
+
+    @classmethod
+    def from_config(cls, config):
+        return cls()
+
+
+def validate_and_resolve_config(mode, config):
+    """Validate and resolve quantization config.
+
+    This function validates the quantization config and resolves the mode.
+    If mode is not provided, it is inferred from the config.
+    If config is not provided, a default config is inferred from the mode.
+
+    Args:
+        mode: Quantization mode.
+        config: Quantization config.
+    """
+    # 1. Backwards Compatibility: Handle string shortcuts.
+    if isinstance(config, str):
+        mode = config
+        config = None
+
+    _validate_mode(mode)
+
+    # 2. Resolve "mode" into a Config object.
+    if config is None:
+        if mode == "int8":
+            config = Int8QuantizationConfig()
+        elif mode == "int4":
+            config = Int4QuantizationConfig()
+        elif mode == "float8":
+            config = Float8QuantizationConfig()
+        elif mode == "gptq":
+            raise ValueError(
+                "For GPTQ, you must pass a `GPTQConfig` object in the "
+                "`config` argument."
+            )
+        else:
+            if mode is not None:
+                raise ValueError(
+                    f"Invalid quantization mode. Received: mode={mode}"
+                )
+            raise ValueError(
+                "You must provide either `mode` or `config` to `quantize`."
+            )
+    else:
+        if not isinstance(config, QuantizationConfig):
+            raise ValueError(
+                "Argument `config` must be an instance of "
+                "`QuantizationConfig`. "
+                f"Received: config={config} (of type {type(config)})"
+            )
+
+    # 3. Validation: Prevent contradictions.
+    if mode is not None and config.mode != mode:
+        raise ValueError(
+            f"Contradictory arguments: mode='{mode}' but "
+            f"config.mode='{config.mode}'"
+        )
+
+    # Ensure mode is consistent.
+    mode = config.mode
+
+    # Ensure the mode derived from the config is valid.
+    _validate_mode(mode)
+
+    if mode == "gptq":
+        from keras.src.quantizers.gptq_config import GPTQConfig
+
+        if not isinstance(config, GPTQConfig):
+            raise ValueError(
+                "Mode 'gptq' requires a valid `config` argument of type "
+                f"`GPTQConfig`. Received: {type(config)}"
+            )
+
+    return config
+
+
+def _validate_mode(mode):
+    """Validates quantization mode."""
+    if mode is not None and mode not in QUANTIZATION_MODES:
+        raise ValueError(
+            "Invalid quantization mode. "
+            f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
+        )
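
End to end, the resolver gives `quantize()` a single normalization point for the `mode`/`config` pair. Behavior sketch, following the code above:

    cfg = validate_and_resolve_config("int8", None)
    assert isinstance(cfg, Int8QuantizationConfig)

    # Legacy shortcut: a bare string in the config slot is treated as a mode.
    cfg = validate_and_resolve_config(None, "float8")
    assert cfg.mode == "float8"

    # Contradictions raise:
    #   validate_and_resolve_config("int4", Int8QuantizationConfig())
    # And bare "gptq" raises, since GPTQ needs a GPTQConfig with
    # calibration data:
    #   validate_and_resolve_config("gptq", None)
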