keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/regularization/dropout.py CHANGED
@@ -48,13 +48,55 @@ class Dropout(Layer):
         )
         self.rate = rate
         self.seed = seed
-        self.noise_shape = noise_shape
+        self.noise_shape = self._validate_noise_shape(noise_shape)
         if rate > 0:
             self.seed_generator = backend.random.SeedGenerator(seed)
         self.supports_masking = True
 
         self._build_at_init()
 
+    def _validate_noise_shape(self, noise_shape):
+        if noise_shape is None:
+            return None
+
+        if isinstance(noise_shape, str):
+            raise ValueError(
+                f"Invalid value received for argument `noise_shape`. "
+                f"Expected a tuple or list of integers. "
+                f"Received: noise_shape={noise_shape}"
+            )
+
+        if not isinstance(noise_shape, tuple):
+            try:
+                noise_shape = tuple(noise_shape)
+            except TypeError:
+                raise ValueError(
+                    f"Invalid value received for argument `noise_shape`. "
+                    f"Expected an iterable of integers "
+                    f"(e.g., a tuple or list). "
+                    f"Received: noise_shape={noise_shape}"
+                )
+
+        for i, dim in enumerate(noise_shape):
+            if dim is not None:
+                if not isinstance(dim, int):
+                    raise ValueError(
+                        f"Invalid value received for argument `noise_shape`. "
+                        f"Expected all elements to be integers or None. "
+                        f"Received element at index {i}: {dim} "
+                        f"(type: {type(dim).__name__})"
+                    )
+
+                if dim <= 0:
+                    raise ValueError(
+                        f"Invalid value received for argument `noise_shape`. "
+                        f"Expected all dimensions to be positive integers "
+                        f"or None. "
+                        f"Received negative or zero value at index {i}: {dim}"
+                    )
+
+        return noise_shape
+
     def call(self, inputs, training=False):
         if training and self.rate > 0:
             return backend.random.dropout(
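A quick sketch of the new constructor-time behavior (standalone illustration, not part of the diff; behavior inferred from the hunk above):

```python
import keras

# Lists are coerced to tuples, and `None` entries mean "match the
# input's dimension at that position", so this is accepted:
layer = keras.layers.Dropout(0.5, noise_shape=[None, 1, 64])

# A string (or any non-iterable of ints/None) now fails fast in
# __init__ instead of erroring later inside call():
try:
    keras.layers.Dropout(0.5, noise_shape="64")
except ValueError as e:
    print(e)  # Invalid value received for argument `noise_shape`. ...
```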
keras/src/layers/rnn/rnn.py CHANGED
@@ -212,6 +212,7 @@ class RNN(Layer):
         self.supports_masking = True
         self.input_spec = None
         self.states = None
+        self._expected_batch_size = None
 
         state_size = getattr(self.cell, "state_size", None)
         if state_size is None:
@@ -283,6 +284,9 @@ class RNN(Layer):
                 f"batch size: sequence.shape={sequences_shape}"
             )
         self._create_state_variables(sequences_shape[0])
+        self._expected_batch_size = ops.shape(
+            tree.flatten(self.states)[0]
+        )[0]
 
     @tracking.no_automatic_dependency_tracking
     def _create_state_variables(self, batch_size):
@@ -382,6 +386,21 @@ class RNN(Layer):
             initial_state = self.get_initial_state(
                 batch_size=ops.shape(sequences)[0]
             )
+        if self.stateful:
+            actual_batch_size = sequences.shape[0]
+            if (
+                self._expected_batch_size is not None
+                and actual_batch_size is not None
+                and actual_batch_size != self._expected_batch_size
+            ):
+                raise ValueError(
+                    f"If an RNN is stateful, the batch size of the "
+                    f"input sequences must be the same as the batch "
+                    f"size of the initial state. \n"
+                    f"- Expected batch size: {self._expected_batch_size}\n"
+                    f"- Received batch size: {actual_batch_size}"
+                )
+
         # RNN expect the states in a list, even if single state.
         if not tree.is_nested(initial_state):
             initial_state = [initial_state]
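For context, a minimal sketch of what the new check catches (standalone illustration; the exact error text is per the hunk above):

```python
import numpy as np
import keras

# Stateful RNNs create their state variables for the batch size seen
# on the first call.
layer = keras.layers.LSTM(4, stateful=True)
layer(np.zeros((8, 3, 2), dtype="float32"))  # states built for batch 8

# A later call with a different static batch size now raises the
# descriptive ValueError above, rather than a backend shape error.
try:
    layer(np.zeros((4, 3, 2), dtype="float32"))
except ValueError as e:
    print(e)
```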
keras/src/losses/loss.py CHANGED
@@ -211,7 +211,7 @@ def apply_mask(sample_weight, mask, dtype, reduction):
             dtype,
         )
         valid = ops.sum(mask)  # May be 0!
-        mask *= total / (valid + backend.epsilon())
+        mask *= ops.divide_no_nan(total, valid)
 
     if sample_weight is not None:
         sample_weight = ops.cast(sample_weight, dtype=dtype)
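The motivation, as a standalone sketch (public `keras.ops`/`keras.config` names; not part of the diff):

```python
from keras import config, ops

# If every element is masked out, `valid` is 0. The old rescaling
# factor total / (valid + epsilon) blows up to ~total / 1e-7, while
# divide_no_nan defines division by a zero denominator as 0.
total = ops.convert_to_tensor(4.0)
valid = ops.convert_to_tensor(0.0)

print(total / (valid + config.epsilon()))  # huge, meaningless scale
print(ops.divide_no_nan(total, valid))     # 0.0
```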
keras/src/losses/losses.py CHANGED
@@ -73,6 +73,14 @@ class MeanSquaredError(LossFunctionWrapper):
             `"float32"` unless set to different value
             (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
             provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+
+    >>> y_true = keras.ops.array([1.0, 0.0, 1.0])
+    >>> y_pred = keras.ops.array([0.9, 0.1, 0.8])
+    >>> loss = keras.losses.MeanSquaredError()
+    >>> loss(y_true, y_pred)
+    0.02
     """
 
     def __init__(
@@ -114,6 +122,14 @@ class MeanAbsoluteError(LossFunctionWrapper):
             `"float32"` unless set to different value
             (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
             provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+
+    >>> y_true = keras.ops.array([1.0, 0.3, 1.0])
+    >>> y_pred = keras.ops.array([1.9, 0.3, 1.8])
+    >>> loss = keras.losses.MeanAbsoluteError()
+    >>> loss(y_true, y_pred)
+    0.5666667
     """
 
     def __init__(
@@ -155,6 +171,14 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
             `"float32"` unless set to different value
             (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
             provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+
+    >>> y_true = keras.ops.array([100.0, 200.0, 300.0])
+    >>> y_pred = keras.ops.array([90.0, 210.0, 310.0])
+    >>> loss = keras.losses.MeanAbsolutePercentageError()
+    >>> loss(y_true, y_pred)
+    6.111111
     """
 
     def __init__(
keras/src/metrics/confusion_metrics.py CHANGED
@@ -654,7 +654,7 @@ class SensitivitySpecificityBase(Metric):
         Args:
             constrained: Over these values the constraint is specified. A rank-1
                 tensor.
-            dependent: From these values the maximum that satiesfies the
+            dependent: From these values the maximum that satisfies the
                 constraint is selected. Values in this tensor and in
                 `constrained` are linked by having the same threshold at each
                 position, hence this tensor must have the same shape.
@@ -664,11 +664,12 @@ class SensitivitySpecificityBase(Metric):
         Returns:
             maximal dependent value, if no value satisfies the constraint 0.0.
         """
-        feasible = ops.nonzero(predicate(constrained, self.value))
-        feasible_exists = ops.greater(ops.size(feasible), 0)
-        max_dependent = ops.max(ops.take(dependent, feasible), initial=0)
-
-        return ops.where(feasible_exists, max_dependent, 0.0)
+        feasible = predicate(constrained, self.value)
+        # Mask values based on whether they satisfy the constraint and take max.
+        return ops.max(
+            ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
+            initial=0,
+        )
 
 
 @keras_export("keras.metrics.SensitivityAtSpecificity")
keras/src/models/cloning.py CHANGED
@@ -293,10 +293,12 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
         input_name = ref_input_layer.name
         input_batch_shape = ref_input_layer.batch_shape
         input_dtype = ref_input_layer._dtype
+        input_optional = ref_input_layer.optional
     else:
         input_name = None
         input_dtype = None
         input_batch_shape = None
+        input_optional = False
 
     if input_tensors is not None:
         if isinstance(input_tensors, (list, tuple)):
@@ -313,6 +315,7 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
             inputs = Input(
                 tensor=input_tensors,
                 name=input_name,
+                optional=input_optional,
             )
             new_layers = [inputs] + new_layers
     else:
@@ -321,6 +324,7 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
                 batch_shape=input_batch_shape,
                 dtype=input_dtype,
                 name=input_name,
+                optional=input_optional,
             )
             new_layers = [inputs] + new_layers
     cloned_model = Sequential(
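In effect (standalone sketch; `optional` is the existing `keras.Input` argument):

```python
import keras
from keras.models import clone_model

# A Sequential model whose input layer is marked optional.
model = keras.Sequential(
    [keras.Input(shape=(4,), optional=True), keras.layers.Dense(2)]
)

# clone_model rebuilds the input layer from the reference model's
# metadata; with this change, optional=True is carried over instead
# of silently resetting to the default.
clone = clone_model(model)
```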
keras/src/models/functional.py CHANGED
@@ -254,9 +254,9 @@ class Functional(Function, Model):
         return converted
 
     def _adjust_input_rank(self, flat_inputs):
-        flat_ref_shapes = [x.shape for x in self._inputs]
         adjusted = []
-        for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
+        for i, x in enumerate(flat_inputs):
+            ref_shape = self._inputs[i].shape
             if x is None:
                 adjusted.append(x)
                 continue
@@ -273,8 +273,11 @@ class Functional(Function, Model):
                 if ref_shape[-1] == 1:
                     adjusted.append(ops.expand_dims(x, axis=-1))
                     continue
+            flat_paths_and_inputs = tree.flatten_with_path(self._inputs_struct)
+            path = ".".join(str(p) for p in flat_paths_and_inputs[i][0])
             raise ValueError(
-                f"Invalid input shape for input {x}. Expected shape "
+                f"Invalid input shape for input {x} with name "
+                f"'{self._inputs[i].name}' and path '{path}'. Expected shape "
                 f"{ref_shape}, but input has incompatible shape {x.shape}"
             )
         # Add back metadata.
@@ -832,11 +835,16 @@ def clone_graph_nodes(inputs, outputs):
             kt_id_mapping[id(kt_input)] = kt_input
         else:
             # We need to create a new Keras tensor for any intermediate tensor
+            original_op = kt_input._keras_history.operation
+            optional = False
+            if isinstance(original_op, InputLayer):
+                optional = original_op.optional
             cloned_input = Input(
                 batch_shape=kt_input.shape,
                 dtype=kt_input.dtype,
                 sparse=kt_input.sparse,
                 name=f"{kt_input.name}CLONE",
+                optional=optional,
             )
             cloned_inputs.append(cloned_input)
             kt_id_mapping[id(kt_input)] = cloned_input
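A hedged illustration of the richer error message (standalone; a rank mismatch that cannot be auto-adjusted is one way to reach this branch):

```python
import numpy as np
import keras

inp = keras.Input(shape=(4,), name="features")
out = keras.layers.Dense(2)(inp)
model = keras.Model({"features": inp}, out)

# Rank-3 input against a rank-2 spec whose last dim is not 1: the
# error now names the input and its path in the input structure.
try:
    model({"features": np.zeros((2, 4, 2), dtype="float32")})
except ValueError as e:
    print(e)  # ... with name 'features' and path 'features' ...
```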
keras/src/models/model.py CHANGED
@@ -2,14 +2,16 @@ import inspect
 import json
 import typing
 import warnings
+from collections.abc import Callable
 
 from keras.src import backend
 from keras.src import utils
 from keras.src.api_export import keras_export
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
-from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.awq_core import awq_quantize
 from keras.src.quantizers.gptq_core import gptq_quantize
+from keras.src.quantizers.utils import should_quantize_layer
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -422,19 +424,99 @@ class Model(Trainer, base_trainer.Trainer, Layer):
             **kwargs,
         )
 
-    def quantize(self, mode, config=None, **kwargs):
+    def get_quantization_layer_structure(self, mode=None):
+        """Returns the quantization structure for the model.
+
+        This method is intended to be overridden by model authors to provide
+        topology information required for structure-aware quantization modes
+        like 'gptq'.
+
+        Args:
+            mode: The quantization mode.
+
+        Returns:
+            A dictionary describing the topology, e.g.:
+            `{'pre_block_layers': [list], 'sequential_blocks': [list]}`
+            or `None` if the mode does not require structure or is not
+            supported. `'pre_block_layers'` is a list of layers that
+            the inputs should be passed through, before being passed to
+            the sequential blocks. For example, inputs to an LLM must
+            first be passed through an embedding layer, followed by
+            the transformer.
+        """
+        del mode  # Unused.
+        return None
+
+    def quantize(self, mode=None, config=None, filters=None, **kwargs):
         """Quantize the weights of the model.
 
         Note that the model must be built first before calling this method.
-        `quantize` will recursively call `quantize(mode)` in all layers and
+        `quantize` will recursively call `quantize(...)` in all layers and
         will be skipped if the layer doesn't implement the function.
 
+        This method can be called by passing a `mode` string, which uses the
+        default configuration for that mode. Alternatively, a `config` object
+        can be passed to customize the behavior of the quantization (e.g. to
+        use specific quantizers for weights or activations).
+
         Args:
-            mode: The mode of the quantization. Only 'int8' is supported at this
-                time.
-        """
-        from keras.src.dtype_policies import QUANTIZATION_MODES
+            mode: The mode of the quantization. Supported modes are:
+                `"int8"`, `"int4"`, `"float8"`, `"gptq"`. This is
+                optional if `config` is provided.
+            config: The configuration object specifying additional
+                quantization options. This argument allows configuring
+                the weight and activation quantizers. Must be an instance
+                of `keras.quantizers.QuantizationConfig`.
+            filters: Optional filters to apply to the quantization. Can be a
+                regex string, a list of regex strings, or a callable. Only the
+                layers which match the filter conditions will be quantized.
+            **kwargs: Additional keyword arguments.
+
+        Example:
+
+        Quantize a model to int8 with default configuration:
 
+        ```python
+        # Build the model
+        model = keras.Sequential([
+            keras.Input(shape=(10,)),
+            keras.layers.Dense(10),
+        ])
+        model.build((None, 10))
+
+        # Quantize with default int8 config
+        model.quantize("int8")
+        ```
+
+        Quantize a model to int8 with a custom configuration:
+
+        ```python
+        from keras.quantizers import Int8QuantizationConfig
+        from keras.quantizers import AbsMaxQuantizer
+
+        # Build the model
+        model = keras.Sequential([
+            keras.Input(shape=(10,)),
+            keras.layers.Dense(10),
+        ])
+        model.build((None, 10))
+
+        # Create a custom config
+        config = Int8QuantizationConfig(
+            weight_quantizer=AbsMaxQuantizer(
+                axis=0,
+                value_range=(-127, 127)
+            ),
+            activation_quantizer=AbsMaxQuantizer(
+                axis=-1,
+                value_range=(-127, 127)
+            ),
+        )
+
+        # Quantize with custom config
+        model.quantize(config=config)
+        ```
+        """
         # Validate inputs.
         type_check = kwargs.pop("type_check", True)
         if kwargs:
@@ -443,27 +525,20 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                 f"passed to {self.__class__.__name__}: {kwargs}"
             )
 
-        if mode not in QUANTIZATION_MODES:
-            raise ValueError(
-                "Invalid quantization mode. "
-                f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
-            )
-
-        if mode == "gptq":
-            if not isinstance(config, GPTQConfig):
+        if filters is not None:
+            if not isinstance(filters, (str, Callable, list, tuple)):
                 raise ValueError(
-                    "Mode 'gptq' requires a valid `config` argument of type "
-                    f"`GPTQConfig`. Received: {type(config)}"
+                    "The `filters` argument must be a regex string, a list of "
+                    "regex strings, or a callable. Received: "
+                    f"{type(filters)}"
                 )
-        elif config is not None:
-            # All other modes must not receive a config
-            raise ValueError(
-                f"The `config` argument is only supported for 'gptq' mode, "
-                f"but received mode='{mode}' and a non-None config."
-            )
 
         graph_modified = False
         for layer in self._flatten_layers():
+            # Apply filters
+            if not should_quantize_layer(layer, filters):
+                continue
+
             if len(list(layer._flatten_layers())) == 1:
                 try:
                     layer.quantize(mode, type_check=type_check, config=config)
@@ -473,8 +548,29 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                 except AttributeError:
                     pass
 
-        if mode == "gptq":
-            gptq_quantize(self, config)
+        if mode in ["gptq", "awq"]:
+            # Resolve model structure.
+            # 1. If quantization_layer_structure is provided inside the config,
+            #    use that.
+            structure = config.quantization_layer_structure
+            # 2. If no layer structure is provided in the config, try to fetch
+            #    it using the `get_quantization_layer_structure` hook.
+            if structure is None:
+                structure = self.get_quantization_layer_structure(mode)
+
+            if structure is None:
+                raise ValueError(
+                    f"For {mode=}, a valid quantization structure must be "
+                    "provided either via `config.quantization_layer_structure` "
+                    "or by overriding "
+                    "`model.get_quantization_layer_structure(mode)`. The "
+                    "structure should be a dictionary with keys "
+                    "'pre_block_layers' and 'sequential_blocks'."
+                )
+            if mode == "gptq":
+                gptq_quantize(config, structure, filters=filters)
+            elif mode == "awq":
+                awq_quantize(config, structure, filters=filters)
 
         # If any layer was changed, we must rebuild the execution functions.
         if graph_modified:
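A short sketch of the new `filters` argument (matching semantics are assumed to follow `should_quantize_layer`, i.e. regexes tested against layer names/paths):

```python
import keras

model = keras.Sequential(
    [
        keras.Input(shape=(8,)),
        keras.layers.Dense(16, name="hidden"),
        keras.layers.Dense(4, name="head"),
    ]
)
model.build((None, 8))

# Regex string: quantize only layers matching "hidden".
model.quantize("int8", filters="hidden")

# Callable form: receives the layer, returns True to quantize it.
# model.quantize("int8", filters=lambda layer: layer.name == "head")
```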
@@ -569,8 +665,8 @@ class Model(Trainer, base_trainer.Trainer, Layer):
             filepath: `str` or `pathlib.Path` object. The path to save the
                 artifact.
             format: `str`. The export format. Supported values:
-                `"tf_saved_model"` and `"onnx"`. Defaults to
-                `"tf_saved_model"`.
+                `"tf_saved_model"`, `"onnx"`, `"openvino"`, and `"litert"`.
+                Defaults to `"tf_saved_model"`.
             verbose: `bool`. Whether to print a message during export. Defaults
                 to `None`, which uses the default value set by different
                 backends and formats.
@@ -593,6 +689,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                 provided, they will be automatically computed.
             - `opset_version`: Optional `int`. Specific to `format="onnx"`.
                 An integer value that specifies the ONNX opset version.
+            - LiteRT-specific options: Optional keyword arguments specific
+                to `format="litert"`. These are passed directly to the
+                TensorFlow Lite converter and include options like
+                `optimizations`, `representative_dataset`,
+                `experimental_new_quantizer`, `allow_custom_ops`,
+                `enable_select_tf_ops`, etc. See TensorFlow Lite
+                documentation for all available options.
 
         **Note:** This feature is currently supported only with TensorFlow, JAX
         and Torch backends.
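For illustration, passing a converter option through (a sketch assuming the TensorFlow backend and a built `model`; `optimizations` is a documented `tf.lite.TFLiteConverter` attribute and is listed as forwarded above):

```python
import tensorflow as tf

model.export(
    "model.tflite",
    format="litert",
    # Enable the converter's default post-training quantization.
    optimizations=[tf.lite.Optimize.DEFAULT],
)
```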
@@ -627,18 +730,41 @@ class Model(Trainer, base_trainer.Trainer, Layer):
            }
            predictions = ort_session.run(None, ort_inputs)
            ```
+
+        Here's how to export a LiteRT (TFLite) artifact for inference:
+
+        ```python
+        # Export the model as a LiteRT artifact
+        model.export("path/to/location", format="litert")
+
+        # Load the artifact in a different process/environment
+        interpreter = tf.lite.Interpreter(model_path="path/to/location")
+        interpreter.allocate_tensors()
+        interpreter.set_tensor(
+            interpreter.get_input_details()[0]['index'], input_data
+        )
+        interpreter.invoke()
+        output_data = interpreter.get_tensor(
+            interpreter.get_output_details()[0]['index']
+        )
+        ```
         """
+        from keras.src.export import export_litert
         from keras.src.export import export_onnx
         from keras.src.export import export_openvino
         from keras.src.export import export_saved_model
 
-        available_formats = ("tf_saved_model", "onnx", "openvino")
+        available_formats = ("tf_saved_model", "onnx", "openvino", "litert")
         if format not in available_formats:
             raise ValueError(
                 f"Unrecognized format={format}. Supported formats are: "
                 f"{list(available_formats)}."
             )
 
+        # Check if LiteRT export is available (requires TensorFlow backend)
+        if format == "litert" and backend.backend() != "tensorflow":
+            raise ImportError("LiteRT export requires TensorFlow backend.")
+
         if format == "tf_saved_model":
             export_saved_model(
                 self,
@@ -663,6 +789,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                 input_signature=input_signature,
                 **kwargs,
             )
+        elif format == "litert":
+            export_litert(
+                self,
+                filepath,
+                input_signature=input_signature,
+                **kwargs,
+            )
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
@@ -863,13 +996,18 @@ class Model(Trainer, base_trainer.Trainer, Layer):
                     self.non_trainable_variables, path_value_dict
                 )
             elif k == "optimizer_variables":
-                self._assign_variable_values(
-                    self.optimizer.variables, path_value_dict
-                )
+                if hasattr(self, "optimizer") and self.optimizer is not None:
+                    self._assign_variable_values(
+                        self.optimizer.variables, path_value_dict
+                    )
             elif k == "metrics_variables":
-                self._assign_variable_values(
-                    self.metrics_variables, path_value_dict
-                )
+                if (
+                    hasattr(self, "metrics_variables")
+                    and self.metrics_variables
+                ):
+                    self._assign_variable_values(
+                        self.metrics_variables, path_value_dict
+                    )
             else:
                 raise ValueError(f"Unknown variable name: {k}")
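The keys handled here (`trainable_variables`, `optimizer_variables`, `metrics_variables`, ...) match `Model.get_state_tree()` / `set_state_tree()`. A plausible scenario the new guards tolerate, assuming that is the surrounding method:

```python
import keras

# Capture state from a compiled model...
src = keras.Sequential([keras.Input((4,)), keras.layers.Dense(2)])
src.compile(optimizer="adam", loss="mse")
state = src.get_state_tree()

# ...and restore it into an uncompiled twin. Optimizer/metrics entries
# are now skipped instead of raising, since the target model has no
# optimizer or metrics variables yet.
dst = keras.Sequential([keras.Input((4,)), keras.layers.Dense(2)])
dst.set_state_tree(state)
```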