keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ from keras.src import ops
3
3
  from keras.src.api_export import keras_export
4
4
  from keras.src.backend.common import global_state
5
5
 
6
- QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
6
+ QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
7
7
 
8
8
 
9
9
  @keras_export(
@@ -288,6 +288,181 @@ class QuantizedFloat8DTypePolicy(QuantizedDTypePolicy):
288
288
  return config
289
289
 
290
290
 
291
+ @keras_export("keras.dtype_policies.GPTQDTypePolicy")
292
+ class GPTQDTypePolicy(QuantizedDTypePolicy):
293
+ """Quantized dtype policy for GPTQ quantization.
294
+
295
+ This policy helps propagate quantization settings for GPTQ
296
+ when loading a GPTQ quantized model in Keras format.
297
+
298
+ Args:
299
+ mode: The quantization mode. This should be a string in the format
300
+ `"gptq/<weight_bits>/<group_size>"`.
301
+ - `"gptq"`: The identifier for the quantization algorithm.
302
+ - `<weight_bits>`: Number of bits to quantize weights to.
303
+ Supported values are 2, 3, 4, and 8.
304
+ - `<group_size>`: The group size for quantization. Supported
305
+ values are -1 (for whole-tensor quantization) or any
306
+ positive integer. Typically a smaller group size leads
307
+ to better accuracy but slower speed.
308
+ Example: `"gptq/4/128"`.
309
+ source_name: The source dtype policy name, e.g. "float32".
310
+ """
311
+
312
+ def __init__(
313
+ self,
314
+ mode,
315
+ source_name=None,
316
+ ):
317
+ parts = mode.split("/")
318
+ expected_format = "'gptq/<weight_bits>/<group_size>'"
319
+
320
+ # Validate format
321
+ if len(parts) != 3 or parts[0] != "gptq":
322
+ raise ValueError(
323
+ "Invalid mode for GPTQDTypePolicy. Expected format "
324
+ f"{expected_format}, but got '{mode}'."
325
+ )
326
+
327
+ # Validate and cast weight_bits and group_size
328
+ try:
329
+ weight_bits = int(parts[1])
330
+ group_size = int(parts[2])
331
+ except ValueError:
332
+ raise ValueError(
333
+ "Invalid mode for GPTQDTypePolicy. <weight_bits> and "
334
+ "<group_size> must be integers. Expected format "
335
+ f"{expected_format}, but got '{mode}'."
336
+ )
337
+
338
+ # Validate supported values
339
+ if weight_bits not in [2, 3, 4, 8]:
340
+ raise ValueError(
341
+ "Invalid weight_bits in mode. Supported values are "
342
+ f"2, 3, 4, and 8, but got {weight_bits} from '{mode}'."
343
+ )
344
+
345
+ if group_size < -1 or group_size == 0:
346
+ raise ValueError(
347
+ "Invalid group_size in mode. Supported values are "
348
+ "-1 (whole-tensor) or a positive integer, "
349
+ f"but got {group_size} from '{mode}'."
350
+ )
351
+
352
+ base_mode = parts[0]
353
+ super().__init__(
354
+ mode=base_mode,
355
+ source_name=source_name,
356
+ )
357
+
358
+ self._name = f"{mode}_from_{source_name}"
359
+ self.mode = base_mode
360
+ self.weight_bits = weight_bits
361
+ self.group_size = group_size
362
+
363
+ def __eq__(self, other):
364
+ if super().__eq__(other) is False:
365
+ return False
366
+ return (
367
+ self.weight_bits == other.weight_bits
368
+ and self.group_size == other.group_size
369
+ )
370
+
371
+ def get_config(self):
372
+ config = super().get_config()
373
+ # Reconstruct the full mode string for serialization
374
+ mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
375
+ config.update({"mode": mode})
376
+ return config
377
+
378
+
379
+ @keras_export("keras.dtype_policies.AWQDTypePolicy")
380
+ class AWQDTypePolicy(QuantizedDTypePolicy):
381
+ """Quantized dtype policy for AWQ quantization.
382
+
383
+ This policy helps propagate quantization settings for AWQ
384
+ when loading an AWQ quantized model in Keras format.
385
+
386
+ Args:
387
+ mode: The quantization mode. This should be a string in the format
388
+ `"awq/<weight_bits>/<group_size>"`.
389
+ - `"awq"`: The identifier for the quantization algorithm.
390
+ - `<weight_bits>`: Number of bits to quantize weights to.
391
+ AWQ presently only supports 4-bit quantization.
392
+ - `<group_size>`: The group size for quantization. Supported
393
+ values are -1 (for per-channel quantization) or any
394
+ positive integer.
395
+ Example: `"awq/4/128"`.
396
+ source_name: The source dtype policy name, e.g. "float32".
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ mode,
402
+ source_name=None,
403
+ ):
404
+ parts = mode.split("/")
405
+ expected_format = "'awq/<weight_bits>/<group_size>'"
406
+
407
+ # Validate format.
408
+ if len(parts) != 3 or parts[0] != "awq":
409
+ raise ValueError(
410
+ "Invalid mode for AWQDTypePolicy. Expected format "
411
+ f"{expected_format}, but got '{mode}'."
412
+ )
413
+
414
+ # Validate and cast weight_bits and group_size.
415
+ try:
416
+ weight_bits = int(parts[1])
417
+ group_size = int(parts[2])
418
+ except ValueError:
419
+ raise ValueError(
420
+ "Invalid mode for AWQDTypePolicy. <weight_bits> and "
421
+ "<group_size> must be integers. Expected format "
422
+ f"{expected_format}, but got '{mode}'."
423
+ )
424
+
425
+ # AWQ presently only supports 4-bit quantization.
426
+ if weight_bits != 4:
427
+ raise ValueError(
428
+ "Invalid weight_bits in mode. AWQ only supports 4-bit "
429
+ f"quantization, but got {weight_bits} from '{mode}'."
430
+ )
431
+
432
+ if group_size < -1 or group_size == 0:
433
+ raise ValueError(
434
+ "Invalid group_size in mode. Supported values are "
435
+ "-1 (per-channel) or a positive integer, "
436
+ f"but got {group_size} from '{mode}'."
437
+ )
438
+
439
+ base_mode = parts[0]
440
+ super().__init__(
441
+ mode=base_mode,
442
+ source_name=source_name,
443
+ )
444
+
445
+ self._name = f"{mode}_from_{source_name}"
446
+ self.mode = base_mode
447
+ self.weight_bits = weight_bits
448
+ self.group_size = group_size
449
+
450
+ def __eq__(self, other):
451
+ if super().__eq__(other) is False:
452
+ return False
453
+ return (
454
+ self.weight_bits == other.weight_bits
455
+ and self.group_size == other.group_size
456
+ )
457
+
458
+ def get_config(self):
459
+ config = super().get_config()
460
+ # Reconstruct the full mode string for serialization
461
+ mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
462
+ config.update({"mode": mode})
463
+ return config
464
+
465
+
291
466
  @keras_export(
292
467
  [
293
468
  "keras.config.set_dtype_policy",
@@ -352,6 +527,10 @@ def _get_quantized_dtype_policy_by_str(policy):
352
527
  mode, source_name = split_name
353
528
  if policy.startswith("int8") or policy.startswith("int4"):
354
529
  return QuantizedDTypePolicy(mode, source_name)
530
+ elif policy.startswith("gptq"):
531
+ return GPTQDTypePolicy(mode, source_name)
532
+ elif policy.startswith("awq"):
533
+ return AWQDTypePolicy(mode, source_name)
355
534
  elif policy.startswith("float8"):
356
535
  return QuantizedFloat8DTypePolicy(mode, source_name)
357
536
  else:
@@ -1,3 +1,5 @@
1
+ from keras.src.export.litert import LiteRTExporter
2
+ from keras.src.export.litert import export_litert
1
3
  from keras.src.export.onnx import export_onnx
2
4
  from keras.src.export.openvino import export_openvino
3
5
  from keras.src.export.saved_model import ExportArchive
@@ -7,6 +7,14 @@ from keras.src.utils.module_utils import tensorflow as tf
7
7
 
8
8
 
9
9
  def get_input_signature(model):
10
+ """Get input signature for model export.
11
+
12
+ Args:
13
+ model: A Keras Model instance.
14
+
15
+ Returns:
16
+ Input signature suitable for model export (always a tuple or list).
17
+ """
10
18
  if not isinstance(model, models.Model):
11
19
  raise TypeError(
12
20
  "The model must be a `keras.Model`. "
@@ -17,13 +25,20 @@ def get_input_signature(model):
17
25
  "The model provided has not yet been built. It must be built "
18
26
  "before export."
19
27
  )
28
+
20
29
  if isinstance(model, models.Functional):
30
+ # Functional models expect a single positional argument `inputs`
31
+ # containing the full nested input structure. We keep the
32
+ # original behavior of returning a single-element list that
33
+ # wraps the mapped structure so that downstream exporters
34
+ # build a tf.function with one positional argument.
21
35
  input_signature = [
22
36
  tree.map_structure(make_input_spec, model._inputs_struct)
23
37
  ]
24
38
  elif isinstance(model, models.Sequential):
25
39
  input_signature = tree.map_structure(make_input_spec, model.inputs)
26
40
  else:
41
+ # Subclassed models: rely on recorded shapes from the first call.
27
42
  input_signature = _infer_input_signature_from_model(model)
28
43
  if not input_signature or not model._called:
29
44
  raise ValueError(
@@ -60,6 +75,7 @@ def _infer_input_signature_from_model(model):
60
75
  f"Unsupported type {type(structure)} for {structure}"
61
76
  )
62
77
 
78
+ # Always return a flat list preserving the order of shapes_dict values
63
79
  return [_make_input_spec(value) for value in shapes_dict.values()]
64
80
 
65
81
 
@@ -86,13 +102,34 @@ def make_input_spec(x):
86
102
  return input_spec
87
103
 
88
104
 
89
- def make_tf_tensor_spec(x):
105
+ def make_tf_tensor_spec(x, dynamic_batch=False):
106
+ """Create a TensorSpec from various input types.
107
+
108
+ Args:
109
+ x: Input to convert (tf.TensorSpec, KerasTensor, or backend tensor).
110
+ dynamic_batch: If True, set the batch dimension to None.
111
+
112
+ Returns:
113
+ A tf.TensorSpec instance.
114
+ """
90
115
  if isinstance(x, tf.TensorSpec):
91
116
  tensor_spec = x
117
+ # Adjust batch dimension if needed
118
+ if dynamic_batch and len(tensor_spec.shape) > 0:
119
+ shape = tuple(
120
+ None if i == 0 else s for i, s in enumerate(tensor_spec.shape)
121
+ )
122
+ tensor_spec = tf.TensorSpec(
123
+ shape, dtype=tensor_spec.dtype, name=tensor_spec.name
124
+ )
92
125
  else:
93
126
  input_spec = make_input_spec(x)
127
+ shape = input_spec.shape
128
+ # Adjust batch dimension if needed and shape is not None
129
+ if dynamic_batch and shape is not None and len(shape) > 0:
130
+ shape = tuple(None if i == 0 else s for i, s in enumerate(shape))
94
131
  tensor_spec = tf.TensorSpec(
95
- input_spec.shape, dtype=input_spec.dtype, name=input_spec.name
132
+ shape, dtype=input_spec.dtype, name=input_spec.name
96
133
  )
97
134
  return tensor_spec
98
135
 
@@ -0,0 +1,248 @@
1
+ from keras.src import layers
2
+ from keras.src import models
3
+ from keras.src import tree
4
+ from keras.src.export.export_utils import get_input_signature
5
+ from keras.src.utils import io_utils
6
+ from keras.src.utils.module_utils import tensorflow as tf
7
+
8
+
9
+ def export_litert(
10
+ model,
11
+ filepath,
12
+ input_signature=None,
13
+ **kwargs,
14
+ ):
15
+ """Export the model as a LiteRT artifact for inference.
16
+
17
+ Args:
18
+ model: The Keras model to export.
19
+ filepath: The path to save the exported artifact.
20
+ input_signature: Optional input signature specification. If
21
+ `None`, it will be inferred.
22
+ **kwargs: Additional keyword arguments passed to the exporter.
23
+ """
24
+
25
+ exporter = LiteRTExporter(
26
+ model=model,
27
+ input_signature=input_signature,
28
+ **kwargs,
29
+ )
30
+ exporter.export(filepath)
31
+ io_utils.print_msg(f"Saved artifact at '{filepath}'.")
32
+
33
+
34
+ class LiteRTExporter:
35
+ """Exporter for the LiteRT (TFLite) format.
36
+
37
+ This class handles the conversion of Keras models for LiteRT runtime and
38
+ generates a `.tflite` model file. For efficient inference on mobile and
39
+ embedded devices, it creates a single callable signature based on the
40
+ model's `call()` method.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ model,
46
+ input_signature=None,
47
+ **kwargs,
48
+ ):
49
+ """Initialize the LiteRT exporter.
50
+
51
+ Args:
52
+ model: The Keras model to export
53
+ input_signature: Input signature specification (e.g., TensorFlow
54
+ TensorSpec or list of TensorSpec)
55
+ **kwargs: Additional export parameters
56
+ """
57
+ self.model = model
58
+ self.input_signature = input_signature
59
+ self.kwargs = kwargs
60
+
61
+ def export(self, filepath):
62
+ """Exports the Keras model to a TFLite file.
63
+
64
+ Args:
65
+ filepath: Output path for the exported model
66
+
67
+ Returns:
68
+ Path to exported model
69
+ """
70
+ # 1. Resolve / infer input signature
71
+ if self.input_signature is None:
72
+ # Use the standard get_input_signature which handles all model types
73
+ # and preserves nested structures (dicts, lists, etc.)
74
+ self.input_signature = get_input_signature(self.model)
75
+
76
+ # 2. Determine input structure and create adapter if needed
77
+ # There are 3 cases:
78
+ # Case 1: Single input (not nested)
79
+ # Case 2: Flat list of inputs (list where flattened == original)
80
+ # Case 3: Nested structure (dicts, nested lists, etc.)
81
+
82
+ # Special handling for Functional models: get_input_signature wraps
83
+ # the structure in a list, so unwrap it for analysis
84
+ input_struct = self.input_signature
85
+ if (
86
+ isinstance(self.input_signature, list)
87
+ and len(self.input_signature) == 1
88
+ ):
89
+ input_struct = self.input_signature[0]
90
+
91
+ if not tree.is_nested(input_struct):
92
+ # Case 1: Single input - use as-is
93
+ model_to_convert = self.model
94
+ signature_for_conversion = self.input_signature
95
+ elif isinstance(input_struct, list) and len(input_struct) == len(
96
+ tree.flatten(input_struct)
97
+ ):
98
+ # Case 2: Flat list of inputs - use as-is
99
+ model_to_convert = self.model
100
+ signature_for_conversion = self.input_signature
101
+ else:
102
+ # Case 3: Nested structure (dict, nested lists, etc.)
103
+ # Create adapter model that converts flat list to nested structure
104
+ adapted_model = self._create_nested_inputs_adapter(input_struct)
105
+
106
+ # Flatten signature for TFLite conversion
107
+ signature_for_conversion = tree.flatten(input_struct)
108
+
109
+ # Use adapted model and flat list signature for conversion
110
+ model_to_convert = adapted_model
111
+
112
+ # Store original model reference for later use
113
+ original_model = self.model
114
+
115
+ # Temporarily replace self.model with the model to convert
116
+ self.model = model_to_convert
117
+
118
+ try:
119
+ # Convert the model to TFLite.
120
+ tflite_model = self._convert_to_tflite(signature_for_conversion)
121
+ finally:
122
+ # Restore original model
123
+ self.model = original_model
124
+
125
+ # Save the TFLite model to the specified file path.
126
+ if not filepath.endswith(".tflite"):
127
+ raise ValueError(
128
+ f"The LiteRT export requires the filepath to end with "
129
+ f"'.tflite'. Got: {filepath}"
130
+ )
131
+
132
+ with open(filepath, "wb") as f:
133
+ f.write(tflite_model)
134
+
135
+ return filepath
136
+
137
+ def _create_nested_inputs_adapter(self, input_signature_struct):
138
+ """Create an adapter model that converts flat list inputs to nested
139
+ structure.
140
+
141
+ This adapter allows models expecting nested inputs (dicts, lists, etc.)
142
+ to be exported to TFLite format (which only supports positional/list
143
+ inputs).
144
+
145
+ Args:
146
+ input_signature_struct: Nested structure of InputSpecs (dict, list,
147
+ etc.)
148
+
149
+ Returns:
150
+ A Functional model that accepts flat list inputs and converts to
151
+ the nested structure before calling the wrapped model.
152
+ """
153
+ # Get flat paths to preserve names and print input mapping
154
+ paths_and_specs = tree.flatten_with_path(input_signature_struct)
155
+ paths = [".".join(str(e) for e in p) for p, v in paths_and_specs]
156
+ io_utils.print_msg(f"Creating adapter for inputs: {paths}")
157
+
158
+ # Create Input layers for TFLite (flat list-based)
159
+ input_layers = []
160
+ for path, spec in paths_and_specs:
161
+ # Extract the input name from spec or path
162
+ name = (
163
+ spec.name
164
+ if hasattr(spec, "name") and spec.name
165
+ else (str(path[-1]) if path else "input")
166
+ )
167
+
168
+ input_layer = layers.Input(
169
+ shape=spec.shape[1:], # Remove batch dimension
170
+ dtype=spec.dtype,
171
+ name=name,
172
+ )
173
+ input_layers.append(input_layer)
174
+
175
+ # Reconstruct the nested structure from flat list
176
+ inputs_structure = tree.pack_sequence_as(
177
+ input_signature_struct, input_layers
178
+ )
179
+
180
+ # Call the original model with nested inputs
181
+ outputs = self.model(inputs_structure)
182
+
183
+ # Build as Functional model (flat list inputs -> nested -> model ->
184
+ # output)
185
+ adapted_model = models.Model(inputs=input_layers, outputs=outputs)
186
+
187
+ # Preserve the original model's variables
188
+ adapted_model._variables = self.model.variables
189
+ adapted_model._trainable_variables = self.model.trainable_variables
190
+ adapted_model._non_trainable_variables = (
191
+ self.model.non_trainable_variables
192
+ )
193
+
194
+ return adapted_model
195
+
196
+ def _convert_to_tflite(self, input_signature):
197
+ """Converts the Keras model to TFLite format.
198
+
199
+ Returns:
200
+ A bytes object containing the serialized TFLite model.
201
+ """
202
+ # Convert directly from the Keras model (no fallback path exists)
203
+ try:
204
+ converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
205
+ converter.target_spec.supported_ops = [
206
+ tf.lite.OpsSet.TFLITE_BUILTINS,
207
+ tf.lite.OpsSet.SELECT_TF_OPS,
208
+ ]
209
+ # Keras 3 only supports resource variables
210
+ converter.experimental_enable_resource_variables = True
211
+
212
+ # Apply any additional converter settings from kwargs
213
+ self._apply_converter_kwargs(converter)
214
+
215
+ tflite_model = converter.convert()
216
+
217
+ return tflite_model
218
+
219
+ except Exception as e:
220
+ # If direct conversion fails, raise the error with helpful message
221
+ raise RuntimeError(
222
+ f"Direct TFLite conversion failed. This may be due to model "
223
+ f"complexity or unsupported operations. Error: {e}"
224
+ ) from e
225
+
226
+ def _apply_converter_kwargs(self, converter):
227
+ """Apply additional converter settings from kwargs.
228
+
229
+ Args:
230
+ converter: tf.lite.TFLiteConverter instance to configure
231
+
232
+ Raises:
233
+ ValueError: If any kwarg is not a valid converter attribute
234
+ """
235
+ for attr, value in self.kwargs.items():
236
+ if attr == "target_spec" and isinstance(value, dict):
237
+ # Handle nested target_spec settings
238
+ for spec_key, spec_value in value.items():
239
+ if hasattr(converter.target_spec, spec_key):
240
+ setattr(converter.target_spec, spec_key, spec_value)
241
+ else:
242
+ raise ValueError(
243
+ f"Unknown target_spec attribute '{spec_key}'"
244
+ )
245
+ elif hasattr(converter, attr):
246
+ setattr(converter, attr, value)
247
+ else:
248
+ raise ValueError(f"Unknown converter attribute '{attr}'")
keras/src/export/onnx.py CHANGED
@@ -80,6 +80,10 @@ def export_onnx(
80
80
  "The model provided has never called. "
81
81
  "It must be called at least once before export."
82
82
  )
83
+ input_names = [
84
+ getattr(spec, "name", None) or f"input_{i}"
85
+ for i, spec in enumerate(input_signature)
86
+ ]
83
87
 
84
88
  if backend.backend() in ("tensorflow", "jax"):
85
89
  from keras.src.utils.module_utils import tf2onnx
@@ -143,6 +147,7 @@ def export_onnx(
143
147
  sample_inputs,
144
148
  verbose=actual_verbose,
145
149
  opset_version=opset_version,
150
+ input_names=input_names,
146
151
  dynamo=True,
147
152
  )
148
153
  if hasattr(onnx_program, "optimize"):
@@ -161,6 +166,7 @@ def export_onnx(
161
166
  filepath,
162
167
  verbose=actual_verbose,
163
168
  opset_version=opset_version,
169
+ input_names=input_names,
164
170
  )
165
171
  else:
166
172
  raise NotImplementedError(
@@ -55,7 +55,7 @@ def export_openvino(
55
55
  )
56
56
 
57
57
  import openvino as ov
58
- from openvino.runtime import opset14 as ov_opset
58
+ import openvino.opset14 as ov_opset
59
59
 
60
60
  from keras.src.backend.openvino.core import OPENVINO_DTYPES
61
61
  from keras.src.backend.openvino.core import OpenVINOKerasTensor
@@ -17,6 +17,9 @@ def patch_tf2onnx():
17
17
 
18
18
  logger = logging.getLogger(tf2onnx.__name__)
19
19
 
20
+ if not hasattr(np, "object"):
21
+ np.object = object
22
+
20
23
  def patched_rewrite_constant_fold(g, ops):
21
24
  """
22
25
  We call tensorflow transform with constant folding but in some cases
@@ -29,6 +29,7 @@ from keras.src.layers.core.input_layer import Input
29
29
  from keras.src.layers.core.input_layer import InputLayer
30
30
  from keras.src.layers.core.lambda_layer import Lambda
31
31
  from keras.src.layers.core.masking import Masking
32
+ from keras.src.layers.core.reversible_embedding import ReversibleEmbedding
32
33
  from keras.src.layers.core.wrapper import Wrapper
33
34
  from keras.src.layers.input_spec import InputSpec
34
35
  from keras.src.layers.layer import Layer
@@ -62,6 +63,18 @@ from keras.src.layers.normalization.spectral_normalization import (
62
63
  SpectralNormalization,
63
64
  )
64
65
  from keras.src.layers.normalization.unit_normalization import UnitNormalization
66
+ from keras.src.layers.pooling.adaptive_average_pooling1d import (
67
+ AdaptiveAveragePooling1D,
68
+ )
69
+ from keras.src.layers.pooling.adaptive_average_pooling2d import (
70
+ AdaptiveAveragePooling2D,
71
+ )
72
+ from keras.src.layers.pooling.adaptive_average_pooling3d import (
73
+ AdaptiveAveragePooling3D,
74
+ )
75
+ from keras.src.layers.pooling.adaptive_max_pooling1d import AdaptiveMaxPooling1D
76
+ from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D
77
+ from keras.src.layers.pooling.adaptive_max_pooling3d import AdaptiveMaxPooling3D
65
78
  from keras.src.layers.pooling.average_pooling1d import AveragePooling1D
66
79
  from keras.src.layers.pooling.average_pooling2d import AveragePooling2D
67
80
  from keras.src.layers.pooling.average_pooling3d import AveragePooling3D
@@ -52,10 +52,15 @@ class Softmax(Layer):
52
52
 
53
53
  def call(self, inputs, mask=None):
54
54
  if mask is not None:
55
- adder = (
56
- 1.0 - backend.cast(mask, inputs.dtype)
57
- ) * _large_negative_number(inputs.dtype)
58
- inputs += adder
55
+ # We keep the positions where the mask is True or > 0.5, and set the
56
+ # other (masked) positions to a large negative number (e.g. -1e9).
57
+ if backend.standardize_dtype(mask.dtype) != "bool":
58
+ mask = backend.numpy.greater(
59
+ mask, backend.cast(0.5, dtype=mask.dtype)
60
+ )
61
+ inputs = backend.numpy.where(
62
+ mask, inputs, _large_negative_number(inputs.dtype)
63
+ )
59
64
  if isinstance(self.axis, (tuple, list)):
60
65
  if len(self.axis) > 1:
61
66
  outputs = backend.numpy.exp(
@@ -121,7 +121,7 @@ class Attention(Layer):
121
121
  if self.score_mode == "dot":
122
122
  scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))
123
123
  if self.scale is not None:
124
- scores *= self.scale
124
+ scores = ops.multiply(scores, self.scale)
125
125
  elif self.score_mode == "concat":
126
126
  # Reshape tensors to enable broadcasting.
127
127
  # Reshape into [batch_size, Tq, 1, dim].
@@ -378,7 +378,10 @@ class MultiHeadAttention(Layer):
378
378
  if self._attention_axes is None:
379
379
  self._attention_axes = tuple(range(1, rank - 2))
380
380
  else:
381
- self._attention_axes = tuple(self._attention_axes)
381
+ self._attention_axes = tuple(
382
+ axis if axis >= 0 else (rank - 1) + axis
383
+ for axis in self._attention_axes
384
+ )
382
385
  (
383
386
  self._dot_product_equation,
384
387
  self._combine_equation,