keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
|
@@ -3,7 +3,7 @@ from keras.src import ops
|
|
|
3
3
|
from keras.src.api_export import keras_export
|
|
4
4
|
from keras.src.backend.common import global_state
|
|
5
5
|
|
|
6
|
-
QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
|
|
6
|
+
QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq", "awq")
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
@keras_export(
|
|
@@ -376,6 +376,93 @@ class GPTQDTypePolicy(QuantizedDTypePolicy):
|
|
|
376
376
|
return config
|
|
377
377
|
|
|
378
378
|
|
|
379
|
+
@keras_export("keras.dtype_policies.AWQDTypePolicy")
class AWQDTypePolicy(QuantizedDTypePolicy):
    """Quantized dtype policy for AWQ quantization.

    This policy helps propagate quantization settings for AWQ
    when loading an AWQ quantized model in Keras format.

    Args:
        mode: The quantization mode. This should be a string in the format
            `"awq/<weight_bits>/<group_size>"`.
            - `"awq"`: The identifier for the quantization algorithm.
            - `<weight_bits>`: Number of bits to quantize weights to.
                AWQ presently only supports 4-bit quantization.
            - `<group_size>`: The group size for quantization. Supported
                values are -1 (for per-channel quantization) or any
                positive integer.
            Example: `"awq/4/128"`.
        source_name: The source dtype policy name, e.g. "float32".

    Raises:
        ValueError: If `mode` is malformed, `<weight_bits>` is not 4, or
            `<group_size>` is neither -1 nor a positive integer.
    """

    def __init__(
        self,
        mode,
        source_name=None,
    ):
        parts = mode.split("/")
        expected_format = "'awq/<weight_bits>/<group_size>'"

        # Validate format.
        if len(parts) != 3 or parts[0] != "awq":
            raise ValueError(
                "Invalid mode for AWQDTypePolicy. Expected format "
                f"{expected_format}, but got '{mode}'."
            )

        # Validate and cast weight_bits and group_size. The original
        # int() failure adds nothing beyond this message, so suppress the
        # chained traceback.
        try:
            weight_bits = int(parts[1])
            group_size = int(parts[2])
        except ValueError:
            raise ValueError(
                "Invalid mode for AWQDTypePolicy. <weight_bits> and "
                "<group_size> must be integers. Expected format "
                f"{expected_format}, but got '{mode}'."
            ) from None

        # AWQ presently only supports 4-bit quantization.
        if weight_bits != 4:
            raise ValueError(
                "Invalid weight_bits in mode. AWQ only supports 4-bit "
                f"quantization, but got {weight_bits} from '{mode}'."
            )

        if group_size < -1 or group_size == 0:
            raise ValueError(
                "Invalid group_size in mode. Supported values are "
                "-1 (per-channel) or a positive integer, "
                f"but got {group_size} from '{mode}'."
            )

        # The parent policy only understands the bare algorithm name.
        base_mode = parts[0]
        super().__init__(
            mode=base_mode,
            source_name=source_name,
        )

        # Keep the full "awq/<bits>/<group>" string in the policy name so
        # the quantization parameters are visible in the name.
        # NOTE(review): when `source_name` is None this embeds the literal
        # "None" rather than the source policy the parent resolved; confirm
        # whether that is intended.
        self._name = f"{mode}_from_{source_name}"
        self.mode = base_mode
        self.weight_bits = weight_bits
        self.group_size = group_size

    def __eq__(self, other):
        # Only another AWQ policy can compare equal. Deferring with
        # NotImplemented avoids reading `weight_bits`/`group_size` from
        # objects that may not define them.
        if not isinstance(other, AWQDTypePolicy):
            return NotImplemented
        if super().__eq__(other) is False:
            return False
        return (
            self.weight_bits == other.weight_bits
            and self.group_size == other.group_size
        )

    def get_config(self):
        config = super().get_config()
        # Reconstruct the full mode string for serialization so that
        # `__init__` can re-parse it on load.
        mode = f"{self.mode}/{self.weight_bits}/{self.group_size}"
        config.update({"mode": mode})
        return config
|
|
464
|
+
|
|
465
|
+
|
|
379
466
|
@keras_export(
|
|
380
467
|
[
|
|
381
468
|
"keras.config.set_dtype_policy",
|
|
@@ -442,6 +529,8 @@ def _get_quantized_dtype_policy_by_str(policy):
|
|
|
442
529
|
return QuantizedDTypePolicy(mode, source_name)
|
|
443
530
|
elif policy.startswith("gptq"):
|
|
444
531
|
return GPTQDTypePolicy(mode, source_name)
|
|
532
|
+
elif policy.startswith("awq"):
|
|
533
|
+
return AWQDTypePolicy(mode, source_name)
|
|
445
534
|
elif policy.startswith("float8"):
|
|
446
535
|
return QuantizedFloat8DTypePolicy(mode, source_name)
|
|
447
536
|
else:
|
keras/src/export/__init__.py
CHANGED
keras/src/export/export_utils.py
CHANGED
|
@@ -7,6 +7,14 @@ from keras.src.utils.module_utils import tensorflow as tf
|
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
def get_input_signature(model):
|
|
10
|
+
"""Get input signature for model export.
|
|
11
|
+
|
|
12
|
+
Args:
|
|
13
|
+
model: A Keras Model instance.
|
|
14
|
+
|
|
15
|
+
Returns:
|
|
16
|
+
Input signature suitable for model export (always a tuple or list).
|
|
17
|
+
"""
|
|
10
18
|
if not isinstance(model, models.Model):
|
|
11
19
|
raise TypeError(
|
|
12
20
|
"The model must be a `keras.Model`. "
|
|
@@ -17,13 +25,20 @@ def get_input_signature(model):
|
|
|
17
25
|
"The model provided has not yet been built. It must be built "
|
|
18
26
|
"before export."
|
|
19
27
|
)
|
|
28
|
+
|
|
20
29
|
if isinstance(model, models.Functional):
|
|
30
|
+
# Functional models expect a single positional argument `inputs`
|
|
31
|
+
# containing the full nested input structure. We keep the
|
|
32
|
+
# original behavior of returning a single-element list that
|
|
33
|
+
# wraps the mapped structure so that downstream exporters
|
|
34
|
+
# build a tf.function with one positional argument.
|
|
21
35
|
input_signature = [
|
|
22
36
|
tree.map_structure(make_input_spec, model._inputs_struct)
|
|
23
37
|
]
|
|
24
38
|
elif isinstance(model, models.Sequential):
|
|
25
39
|
input_signature = tree.map_structure(make_input_spec, model.inputs)
|
|
26
40
|
else:
|
|
41
|
+
# Subclassed models: rely on recorded shapes from the first call.
|
|
27
42
|
input_signature = _infer_input_signature_from_model(model)
|
|
28
43
|
if not input_signature or not model._called:
|
|
29
44
|
raise ValueError(
|
|
@@ -60,6 +75,7 @@ def _infer_input_signature_from_model(model):
|
|
|
60
75
|
f"Unsupported type {type(structure)} for {structure}"
|
|
61
76
|
)
|
|
62
77
|
|
|
78
|
+
# Always return a flat list preserving the order of shapes_dict values
|
|
63
79
|
return [_make_input_spec(value) for value in shapes_dict.values()]
|
|
64
80
|
|
|
65
81
|
|
|
@@ -86,13 +102,34 @@ def make_input_spec(x):
|
|
|
86
102
|
return input_spec
|
|
87
103
|
|
|
88
104
|
|
|
89
|
-
def make_tf_tensor_spec(x):
|
|
105
|
+
def _with_dynamic_batch(shape):
    """Return `shape` as a tuple whose leading (batch) dimension is None.

    Args:
        shape: A sequence of dimensions with at least one entry.
    """
    return (None, *tuple(shape)[1:])


def make_tf_tensor_spec(x, dynamic_batch=False):
    """Create a TensorSpec from various input types.

    Args:
        x: Input to convert (tf.TensorSpec, KerasTensor, or backend tensor).
        dynamic_batch: If True, set the batch dimension to None.

    Returns:
        A tf.TensorSpec instance.
    """
    if isinstance(x, tf.TensorSpec):
        tensor_spec = x
        # `TensorShape.rank` is None for unknown-rank specs, where `len()`
        # would raise; only rewrite shapes that have a batch dimension.
        if dynamic_batch and tensor_spec.shape.rank:
            tensor_spec = tf.TensorSpec(
                _with_dynamic_batch(tensor_spec.shape),
                dtype=tensor_spec.dtype,
                name=tensor_spec.name,
            )
        return tensor_spec

    input_spec = make_input_spec(x)
    shape = input_spec.shape
    # Adjust the batch dimension only when the shape is known and non-scalar.
    if dynamic_batch and shape is not None and len(shape) > 0:
        shape = _with_dynamic_batch(shape)
    return tf.TensorSpec(shape, dtype=input_spec.dtype, name=input_spec.name)
|
|
98
135
|
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
from keras.src import layers
|
|
2
|
+
from keras.src import models
|
|
3
|
+
from keras.src import tree
|
|
4
|
+
from keras.src.export.export_utils import get_input_signature
|
|
5
|
+
from keras.src.utils import io_utils
|
|
6
|
+
from keras.src.utils.module_utils import tensorflow as tf
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def export_litert(
    model,
    filepath,
    input_signature=None,
    **kwargs,
):
    """Export `model` as a LiteRT artifact for inference.

    Args:
        model: The Keras model to export.
        filepath: Destination path of the exported artifact.
        input_signature: Optional input signature specification. When
            `None`, the exporter infers it from the model.
        **kwargs: Extra keyword arguments forwarded to `LiteRTExporter`.
    """
    LiteRTExporter(
        model=model, input_signature=input_signature, **kwargs
    ).export(filepath)
    io_utils.print_msg(f"Saved artifact at '{filepath}'.")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class LiteRTExporter:
    """Exporter for the LiteRT (TFLite) format.

    Converts a Keras model for the LiteRT runtime and writes a `.tflite`
    model file. For efficient inference on mobile and embedded devices, a
    single callable signature is created from the model's `call()` method.
    """

    def __init__(
        self,
        model,
        input_signature=None,
        **kwargs,
    ):
        """Initialize the LiteRT exporter.

        Args:
            model: The Keras model to export.
            input_signature: Input signature specification (e.g. a
                TensorFlow `TensorSpec` or a list of them). Inferred from
                the model when `None`.
            **kwargs: Extra converter settings, applied to the
                `tf.lite.TFLiteConverter` by `_apply_converter_kwargs`.
        """
        self.model = model
        self.input_signature = input_signature
        self.kwargs = kwargs

    def export(self, filepath):
        """Export the Keras model to a TFLite file.

        Args:
            filepath: Output path for the exported model. Must end with
                `.tflite`; strings and path-like objects are accepted.

        Returns:
            `filepath`, as passed in.

        Raises:
            ValueError: If `filepath` does not end with `.tflite`, or a
                converter kwarg is invalid.
            RuntimeError: If the TFLite conversion fails.
        """
        # Validate the destination before doing any work: conversion can be
        # slow, and a bad extension should not be reported only afterwards.
        if not str(filepath).endswith(".tflite"):
            raise ValueError(
                "The LiteRT export requires the filepath to end with "
                f"'.tflite'. Got: {filepath}"
            )

        # 1. Resolve / infer the input signature. `get_input_signature`
        # handles all model types and preserves nested structures.
        if self.input_signature is None:
            self.input_signature = get_input_signature(self.model)

        # 2. Choose what to convert. There are three cases:
        #   Case 1: single input (not nested),
        #   Case 2: flat list of inputs (flattened == original),
        #   Case 3: nested structure (dicts, nested lists, ...).
        # For Functional models `get_input_signature` wraps the structure in
        # a single-element list, so unwrap it for the analysis.
        input_struct = self.input_signature
        if isinstance(input_struct, list) and len(input_struct) == 1:
            input_struct = input_struct[0]

        is_flat_list = isinstance(input_struct, list) and len(
            input_struct
        ) == len(tree.flatten(input_struct))
        if not tree.is_nested(input_struct) or is_flat_list:
            # Cases 1 and 2: the model can be converted as-is.
            model_to_convert = self.model
            signature_for_conversion = self.input_signature
        else:
            # Case 3: wrap the model in an adapter that takes a flat list of
            # inputs and repacks them into the nested structure.
            model_to_convert = self._create_nested_inputs_adapter(input_struct)
            signature_for_conversion = tree.flatten(input_struct)

        # `_convert_to_tflite` reads `self.model`; swap it temporarily and
        # always restore the original, even if conversion fails.
        original_model = self.model
        self.model = model_to_convert
        try:
            tflite_model = self._convert_to_tflite(signature_for_conversion)
        finally:
            self.model = original_model

        with open(filepath, "wb") as f:
            f.write(tflite_model)

        return filepath

    def _create_nested_inputs_adapter(self, input_signature_struct):
        """Create an adapter model mapping flat list inputs to a nested
        structure.

        TFLite only supports positional (list) inputs, so a model expecting
        nested inputs (dicts, lists, ...) is wrapped. The adapter has one
        `Input` per leaf, repacks them into the original structure, and
        calls `self.model` on the result.

        Args:
            input_signature_struct: Nested structure of input specs (dict,
                list, ...). Each leaf is read for `shape` (batch dimension
                first) and `dtype`, and optionally `name`.

        Returns:
            A Functional model that accepts flat list inputs and converts
            them to the nested structure.
        """
        # Flat paths preserve names and give a readable input mapping.
        paths_and_specs = tree.flatten_with_path(input_signature_struct)
        paths = [".".join(str(e) for e in path) for path, _ in paths_and_specs]
        io_utils.print_msg(f"Creating adapter for inputs: {paths}")

        # Create Input layers for TFLite (flat list-based).
        input_layers = []
        used_names = set()
        for path, spec in paths_and_specs:
            # Prefer the spec's own name, else the last path element.
            name = (
                spec.name
                if getattr(spec, "name", None)
                else (str(path[-1]) if path else "input")
            )
            # Leaf keys can repeat across branches (e.g. {"a": {"x": ...},
            # "b": {"x": ...}}), and Functional models reject duplicate
            # layer names: fall back to the full path, then a suffix.
            if name in used_names:
                base = "_".join(str(e) for e in path) or "input"
                name = base
                suffix = 1
                while name in used_names:
                    name = f"{base}_{suffix}"
                    suffix += 1
            used_names.add(name)

            input_layers.append(
                layers.Input(
                    shape=spec.shape[1:],  # Remove batch dimension
                    dtype=spec.dtype,
                    name=name,
                )
            )

        # Reconstruct the nested structure from the flat list and call the
        # original model with it.
        inputs_structure = tree.pack_sequence_as(
            input_signature_struct, input_layers
        )
        outputs = self.model(inputs_structure)

        # Flat list inputs -> nested -> model -> outputs.
        adapted_model = models.Model(inputs=input_layers, outputs=outputs)

        # Preserve the original model's variables.
        adapted_model._variables = self.model.variables
        adapted_model._trainable_variables = self.model.trainable_variables
        adapted_model._non_trainable_variables = (
            self.model.non_trainable_variables
        )

        return adapted_model

    def _convert_to_tflite(self, input_signature):
        """Convert `self.model` to TFLite format.

        Args:
            input_signature: The signature chosen by `export`.
                NOTE(review): currently unused. `from_keras_model` derives
                the signature from the model itself, so a user-supplied
                signature does not constrain the converted graph; confirm
                whether that is intended.

        Returns:
            A bytes object containing the serialized TFLite model.

        Raises:
            ValueError: If a converter kwarg is not a valid attribute.
            RuntimeError: If building the converter or converting fails.
        """
        try:
            converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
        except Exception as e:
            raise self._conversion_error(e) from e

        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        # Keras 3 only supports resource variables.
        converter.experimental_enable_resource_variables = True

        # Applied outside the conversion `try` so that an invalid kwarg
        # surfaces as the documented ValueError instead of a misleading
        # "conversion failed" RuntimeError.
        self._apply_converter_kwargs(converter)

        try:
            return converter.convert()
        except Exception as e:
            raise self._conversion_error(e) from e

    @staticmethod
    def _conversion_error(error):
        """Build the RuntimeError reported when TFLite conversion fails."""
        return RuntimeError(
            "Direct TFLite conversion failed. This may be due to model "
            f"complexity or unsupported operations. Error: {error}"
        )

    def _apply_converter_kwargs(self, converter):
        """Apply additional converter settings from kwargs.

        Args:
            converter: tf.lite.TFLiteConverter instance to configure.

        Raises:
            ValueError: If any kwarg is not a valid converter attribute.
        """
        for attr, value in self.kwargs.items():
            if attr == "target_spec" and isinstance(value, dict):
                # Handle nested target_spec settings key by key.
                for spec_key, spec_value in value.items():
                    if hasattr(converter.target_spec, spec_key):
                        setattr(converter.target_spec, spec_key, spec_value)
                    else:
                        raise ValueError(
                            f"Unknown target_spec attribute '{spec_key}'"
                        )
            elif hasattr(converter, attr):
                setattr(converter, attr, value)
            else:
                raise ValueError(f"Unknown converter attribute '{attr}'")
|
keras/src/export/openvino.py
CHANGED
|
@@ -55,7 +55,7 @@ def export_openvino(
|
|
|
55
55
|
)
|
|
56
56
|
|
|
57
57
|
import openvino as ov
|
|
58
|
-
|
|
58
|
+
import openvino.opset14 as ov_opset
|
|
59
59
|
|
|
60
60
|
from keras.src.backend.openvino.core import OPENVINO_DTYPES
|
|
61
61
|
from keras.src.backend.openvino.core import OpenVINOKerasTensor
|
keras/src/export/tf2onnx_lib.py
CHANGED
keras/src/layers/__init__.py
CHANGED
|
@@ -29,6 +29,7 @@ from keras.src.layers.core.input_layer import Input
|
|
|
29
29
|
from keras.src.layers.core.input_layer import InputLayer
|
|
30
30
|
from keras.src.layers.core.lambda_layer import Lambda
|
|
31
31
|
from keras.src.layers.core.masking import Masking
|
|
32
|
+
from keras.src.layers.core.reversible_embedding import ReversibleEmbedding
|
|
32
33
|
from keras.src.layers.core.wrapper import Wrapper
|
|
33
34
|
from keras.src.layers.input_spec import InputSpec
|
|
34
35
|
from keras.src.layers.layer import Layer
|
|
@@ -62,6 +63,18 @@ from keras.src.layers.normalization.spectral_normalization import (
|
|
|
62
63
|
SpectralNormalization,
|
|
63
64
|
)
|
|
64
65
|
from keras.src.layers.normalization.unit_normalization import UnitNormalization
|
|
66
|
+
from keras.src.layers.pooling.adaptive_average_pooling1d import (
|
|
67
|
+
AdaptiveAveragePooling1D,
|
|
68
|
+
)
|
|
69
|
+
from keras.src.layers.pooling.adaptive_average_pooling2d import (
|
|
70
|
+
AdaptiveAveragePooling2D,
|
|
71
|
+
)
|
|
72
|
+
from keras.src.layers.pooling.adaptive_average_pooling3d import (
|
|
73
|
+
AdaptiveAveragePooling3D,
|
|
74
|
+
)
|
|
75
|
+
from keras.src.layers.pooling.adaptive_max_pooling1d import AdaptiveMaxPooling1D
|
|
76
|
+
from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D
|
|
77
|
+
from keras.src.layers.pooling.adaptive_max_pooling3d import AdaptiveMaxPooling3D
|
|
65
78
|
from keras.src.layers.pooling.average_pooling1d import AveragePooling1D
|
|
66
79
|
from keras.src.layers.pooling.average_pooling2d import AveragePooling2D
|
|
67
80
|
from keras.src.layers.pooling.average_pooling3d import AveragePooling3D
|
|
@@ -52,10 +52,15 @@ class Softmax(Layer):
|
|
|
52
52
|
|
|
53
53
|
def call(self, inputs, mask=None):
|
|
54
54
|
if mask is not None:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
55
|
+
# We keep the positions where the mask is True or > 0.5, and set the
|
|
56
|
+
# other (masked) positions to -1e.9.
|
|
57
|
+
if backend.standardize_dtype(mask.dtype) != "bool":
|
|
58
|
+
mask = backend.numpy.greater(
|
|
59
|
+
mask, backend.cast(0.5, dtype=mask.dtype)
|
|
60
|
+
)
|
|
61
|
+
inputs = backend.numpy.where(
|
|
62
|
+
mask, inputs, _large_negative_number(inputs.dtype)
|
|
63
|
+
)
|
|
59
64
|
if isinstance(self.axis, (tuple, list)):
|
|
60
65
|
if len(self.axis) > 1:
|
|
61
66
|
outputs = backend.numpy.exp(
|
|
@@ -378,7 +378,10 @@ class MultiHeadAttention(Layer):
|
|
|
378
378
|
if self._attention_axes is None:
|
|
379
379
|
self._attention_axes = tuple(range(1, rank - 2))
|
|
380
380
|
else:
|
|
381
|
-
self._attention_axes = tuple(
|
|
381
|
+
self._attention_axes = tuple(
|
|
382
|
+
axis if axis >= 0 else (rank - 1) + axis
|
|
383
|
+
for axis in self._attention_axes
|
|
384
|
+
)
|
|
382
385
|
(
|
|
383
386
|
self._dot_product_equation,
|
|
384
387
|
self._combine_equation,
|